Skip to content
Snippets Groups Projects
model.py 20.00 KiB
from collections import deque

import numpy as np
import torch 
from torch import nn
import torch.nn.functional as F
from torchvision import models


class ImageCNN(nn.Module):
    """
    Image encoder applied to a list of image tensors.

    Each image in ``inputs`` is passed through one shared ResNet-34 trunk
    (pretrained weights requested, classification head replaced by an empty
    ``nn.Sequential``, i.e. an identity) and the per-image feature vectors are
    summed.

    Args:
        c_dim (int): intended output dimension of the latent embedding.
            NOTE(review): currently unused — the output width is whatever the
            ResNet-34 trunk produces.
        normalize (bool): whether each input image is ImageNet-normalized
            before encoding
    """

    def __init__(self, c_dim, normalize=True):
        super().__init__()
        self.normalize = normalize
        backbone = models.resnet34(pretrained=True)
        # Drop the classifier so the trunk returns pooled features.
        backbone.fc = nn.Sequential()
        self.features = backbone

    def forward(self, inputs):
        """Encode every image in ``inputs`` and return the sum of their features."""
        def encode(image):
            if self.normalize:
                image = normalize_imagenet(image)
            return self.features(image)

        return sum(encode(image) for image in inputs)

def normalize_imagenet(x):
    """Return a copy of ``x`` normalized with the ImageNet channel statistics.

    Channels 0, 1 and 2 (dimension 1) each have the ImageNet mean subtracted
    and are divided by the ImageNet std; any further channels are copied
    unchanged. The input tensor itself is never modified.

    Args:
        x (tensor): batch of images with channels along dimension 1 —
            presumably (B, 3, H, W) with values in [0, 1]; TODO confirm range
    """
    # (mean, std) per RGB channel, in channel order.
    imagenet_stats = ((0.485, 0.229), (0.456, 0.224), (0.406, 0.225))
    normalized = x.clone()
    for channel, (mean, std) in enumerate(imagenet_stats):
        normalized[:, channel] = (normalized[:, channel] - mean) / std
    return normalized


class LidarEncoder(nn.Module):
    """
    Encoder network for a list of LiDAR inputs.

    Wraps a ResNet-18 (no pretrained weights requested) whose classification
    head is replaced by an empty ``nn.Sequential`` (identity) and whose first
    convolution is rebuilt to accept ``in_channels`` input channels. The
    per-input feature vectors are summed.

    Args:
        num_classes: nominal output feature dimension.
            NOTE(review): currently unused — the output width is whatever the
            ResNet-18 trunk produces.
        in_channels: number of channels in each LiDAR input
    """

    def __init__(self, num_classes=512, in_channels=2):
        super().__init__()

        self._model = models.resnet18()
        self._model.fc = nn.Sequential()
        stem = self._model.conv1
        # Rebuild the stem conv with the requested input channels, copying the
        # remaining hyper-parameters. ``bias`` is a bool flag for nn.Conv2d, so
        # derive it from whether the old conv had a bias tensor (passing the
        # tensor/None itself only worked by accident).
        self._model.conv1 = nn.Conv2d(
            in_channels,
            out_channels=stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=stem.bias is not None,
        )

    def forward(self, inputs):
        """Encode every LiDAR input in ``inputs`` and return the summed features."""
        features = 0
        for lidar_data in inputs:
            features = features + self._model(lidar_data)
        return features


class Encoder(nn.Module):
    """
    Multi-scale image + LiDAR fusion encoder using geometric feature projections.

    The image backbone (``ImageCNN.features``) and LiDAR backbone
    (``LidarEncoder._model``) are run ResNet stage by stage. After each of the
    last ``config.n_scale`` stages, features are exchanged between the two
    branches: image features gathered at ``bev_points`` are added to the LiDAR
    branch, and LiDAR features gathered at ``img_points`` are added to the
    image branch.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Both branches' embeddings are pooled to a fixed anchor grid before
        # the geometric gather.
        self.avgpool = nn.AdaptiveAvgPool2d((self.config.vert_anchors, self.config.horz_anchors))

        self.image_encoder = ImageCNN(512, normalize=True)
        self.lidar_encoder = LidarEncoder(num_classes=512, in_channels=2)

        # 1x1 convs mapping each ResNet stage's channels (64/128/256/512) to
        # n_embd ("conv") and back ("deconv").
        self.image_conv1 = nn.Conv2d(64, config.n_embd, 1)
        self.image_conv2 = nn.Conv2d(128, config.n_embd, 1)
        self.image_conv3 = nn.Conv2d(256, config.n_embd, 1)
        self.image_conv4 = nn.Conv2d(512, config.n_embd, 1)
        self.image_deconv1 = nn.Conv2d(config.n_embd, 64, 1)
        self.image_deconv2 = nn.Conv2d(config.n_embd, 128, 1)
        self.image_deconv3 = nn.Conv2d(config.n_embd, 256, 1)
        self.image_deconv4 = nn.Conv2d(config.n_embd, 512, 1)

        self.lidar_conv1 = nn.Conv2d(64, config.n_embd, 1)
        self.lidar_conv2 = nn.Conv2d(128, config.n_embd, 1)
        self.lidar_conv3 = nn.Conv2d(256, config.n_embd, 1)
        self.lidar_conv4 = nn.Conv2d(512, config.n_embd, 1)
        self.lidar_deconv1 = nn.Conv2d(config.n_embd, 64, 1)
        self.lidar_deconv2 = nn.Conv2d(config.n_embd, 128, 1)
        self.lidar_deconv3 = nn.Conv2d(config.n_embd, 256, 1)
        self.lidar_deconv4 = nn.Conv2d(config.n_embd, 512, 1)

        # Per-scale MLPs applied to the gathered n_embd-channel features; the
        # hard-coded width 512 means config.n_embd must equal 512.
        self.image_projection1 = self._projection_mlp()
        self.image_projection2 = self._projection_mlp()
        self.image_projection3 = self._projection_mlp()
        self.image_projection4 = self._projection_mlp()
        self.lidar_projection1 = self._projection_mlp()
        self.lidar_projection2 = self._projection_mlp()
        self.lidar_projection3 = self._projection_mlp()
        self.lidar_projection4 = self._projection_mlp()

    @staticmethod
    def _projection_mlp(dim=512):
        """Build the three-layer Linear+ReLU MLP used for cross-modal projections."""
        return nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(True),
            nn.Linear(dim, dim), nn.ReLU(True),
            nn.Linear(dim, dim), nn.ReLU(True),
        )

    @staticmethod
    def _gather_points(embd, points, bz, curr_h, curr_w):
        """Gather per-sample features at projected point locations.

        ``embd`` is (N, C, H, W). ``points`` is reshaped to
        (bz * curr_h * curr_w * 5, 2) integer (row, col) indices into the H x W
        grid, i.e. 5 source points per target cell. Returns (bz, C, curr_h,
        curr_w) features summed over the 5 points.
        NOTE(review): the view below requires N == bz (one view, seq_len 1).
        """
        flat_points = points.view(bz * curr_h * curr_w * 5, 2)
        # (N, H, W, C) indexed on H/W -> (N, P, C), then split P per sample.
        encoding = embd.permute(0, 2, 3, 1).contiguous()[:, flat_points[:, 0], flat_points[:, 1]]
        encoding = encoding.view(bz, bz, curr_h, curr_w, 5, -1)
        # Keep only sample b's features at sample b's points:
        # diagonal -> (h, w, 5, C, bz), permute -> (bz, C, h, w, 5).
        encoding = torch.diagonal(encoding, 0).permute(4, 3, 0, 1, 2).contiguous()
        return torch.sum(encoding, -1)

    @staticmethod
    def _project(projection, encoding):
        """Apply a channel-wise MLP to a (B, C, H, W) map and return (B, C', H, W)."""
        return projection(encoding.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()

    def _fuse(self, level, image_features, lidar_features, bev_points, img_points, bz, scale_factor):
        """Exchange features between both branches after ResNet stage ``level``.

        Uses the ``*_conv{level}``, ``*_projection{level}`` and
        ``*_deconv{level}`` modules. ``scale_factor`` upsamples the pooled
        anchor grid back to the stage's feature-map size (1 means no resize).
        Returns the updated ``(image_features, lidar_features)``.
        """
        image_conv = getattr(self, f'image_conv{level}')
        lidar_conv = getattr(self, f'lidar_conv{level}')
        image_projection = getattr(self, f'image_projection{level}')
        lidar_projection = getattr(self, f'lidar_projection{level}')
        image_deconv = getattr(self, f'image_deconv{level}')
        lidar_deconv = getattr(self, f'lidar_deconv{level}')

        # Both embeddings are taken from the pre-fusion features of this stage.
        image_embd = self.avgpool(image_conv(image_features))
        lidar_embd = self.avgpool(lidar_conv(lidar_features))
        curr_h, curr_w = image_embd.shape[-2:]

        # project image features to bev
        bev_encoding = self._project(image_projection, self._gather_points(image_embd, bev_points, bz, curr_h, curr_w))
        if scale_factor != 1:
            bev_encoding = F.interpolate(bev_encoding, scale_factor=scale_factor, mode='bilinear')
        lidar_features = lidar_features + lidar_deconv(bev_encoding)

        # project bev features to image
        img_encoding = self._project(lidar_projection, self._gather_points(lidar_embd, img_points, bz, curr_h, curr_w))
        if scale_factor != 1:
            img_encoding = F.interpolate(img_encoding, scale_factor=scale_factor, mode='bilinear')
        image_features = image_features + image_deconv(img_encoding)

        return image_features, lidar_features

    def forward(self, image_list, lidar_list, velocity, bev_points, img_points):
        '''
        Image + LiDAR feature fusion using geometric projections.

        Args:
            image_list (list): list of input images
            lidar_list (list): list of input LiDAR BEV
            velocity (tensor): input velocity from speedometer (currently unused)
            bev_points (tensor): projected image pixels onto the BEV grid
            img_points (tensor): projected LiDAR point cloud onto the image space

        Returns:
            tensor: fused features of shape (bz, C), the sum over all image and
            LiDAR per-frame feature vectors.

        NOTE: earlier revisions used ``image_projection1`` at the second scale
        and the third-scale LiDAR embedding at the fourth scale; checkpoints
        trained with that code will not reproduce their outputs exactly.
        '''
        if self.image_encoder.normalize:
            image_list = [normalize_imagenet(image_input) for image_input in image_list]

        bz, _, h, w = lidar_list[0].shape
        img_channel = image_list[0].shape[1]
        lidar_channel = lidar_list[0].shape[1]
        # Kept as a side effect: other code may read config.n_views.
        self.config.n_views = len(image_list) // self.config.seq_len

        # NOTE(review): images are reshaped with the LiDAR (h, w), so both
        # modalities must share a spatial resolution.
        image_tensor = torch.stack(image_list, dim=1).view(bz * self.config.n_views * self.config.seq_len, img_channel, h, w)
        lidar_tensor = torch.stack(lidar_list, dim=1).view(bz * self.config.seq_len, lidar_channel, h, w)

        image_net = self.image_encoder.features
        lidar_net = self.lidar_encoder._model

        # ResNet stems
        image_features = image_net.maxpool(image_net.relu(image_net.bn1(image_net.conv1(image_tensor))))
        lidar_features = lidar_net.maxpool(lidar_net.relu(lidar_net.bn1(lidar_net.conv1(lidar_tensor))))

        image_features = image_net.layer1(image_features.contiguous())
        lidar_features = lidar_net.layer1(lidar_features.contiguous())
        if self.config.n_scale >= 4:
            # fusion at (B, 64, 64, 64)
            image_features, lidar_features = self._fuse(
                1, image_features, lidar_features, bev_points, img_points, bz, scale_factor=8)

        image_features = image_net.layer2(image_features.contiguous())
        lidar_features = lidar_net.layer2(lidar_features.contiguous())
        if self.config.n_scale >= 3:
            # fusion at (B, 128, 32, 32)
            image_features, lidar_features = self._fuse(
                2, image_features, lidar_features, bev_points, img_points, bz, scale_factor=4)

        image_features = image_net.layer3(image_features.contiguous())
        lidar_features = lidar_net.layer3(lidar_features.contiguous())
        if self.config.n_scale >= 2:
            # fusion at (B, 256, 16, 16)
            image_features, lidar_features = self._fuse(
                3, image_features, lidar_features, bev_points, img_points, bz, scale_factor=2)

        image_features = image_net.layer4(image_features.contiguous())
        lidar_features = lidar_net.layer4(lidar_features.contiguous())
        if self.config.n_scale >= 1:
            # fusion at (B, 512, 8, 8)
            image_features, lidar_features = self._fuse(
                4, image_features, lidar_features, bev_points, img_points, bz, scale_factor=1)

        image_features = image_net.avgpool(image_features)
        image_features = torch.flatten(image_features, 1)
        image_features = image_features.view(bz, self.config.n_views * self.config.seq_len, -1)
        lidar_features = lidar_net.avgpool(lidar_features)
        lidar_features = torch.flatten(lidar_features, 1)
        lidar_features = lidar_features.view(bz, self.config.seq_len, -1)

        fused_features = torch.cat([image_features, lidar_features], dim=1)
        return torch.sum(fused_features, dim=1)


class PIDController:
    """
    Discrete PID controller over a sliding window of recent errors.

    The integral term is the *mean* of the last ``n`` errors (the window
    starts pre-filled with ``n`` zeros), and the derivative term is the
    difference between the two most recent errors.

    Args:
        K_P (float): proportional gain
        K_I (float): integral gain
        K_D (float): derivative gain
        n (int): window length used for the integral / derivative terms
    """

    def __init__(self, K_P=1.0, K_I=0.0, K_D=0.0, n=20):
        self._K_P, self._K_I, self._K_D = K_P, K_I, K_D
        self._window = deque([0] * n, maxlen=n)
        # Largest |error| seen so far and its negation; tracked, not used by step().
        self._max = 0.0
        self._min = 0.0

    def step(self, error):
        """Record ``error`` and return the combined PID control output."""
        history = self._window
        history.append(error)
        self._max = max(self._max, abs(error))
        self._min = -abs(self._max)

        integral = derivative = 0.0
        if len(history) > 1:
            integral = np.mean(history)
            derivative = history[-1] - history[-2]

        proportional = self._K_P * error
        return proportional + self._K_I * integral + self._K_D * derivative


class GeometricFusion(nn.Module):
    '''
    Image + LiDAR feature fusion using geometric projections followed by
    GRU-based waypoint prediction network and PID controller
    '''

    def __init__(self, config, device):
        super().__init__()
        self.device = device
        self.config = config
        self.pred_len = config.pred_len

        self.turn_controller = PIDController(K_P=config.turn_KP, K_I=config.turn_KI, K_D=config.turn_KD, n=config.turn_n)
        self.speed_controller = PIDController(K_P=config.speed_KP, K_I=config.speed_KI, K_D=config.speed_KD, n=config.speed_n)

        self.encoder = Encoder(config).to(self.device)

        self.join = nn.Sequential(
                            nn.Linear(512, 256),
                            nn.ReLU(inplace=True),
                            nn.Linear(256, 128),
                            nn.ReLU(inplace=True),
                            nn.Linear(128, 64),
                            nn.ReLU(inplace=True),
                        ).to(self.device)
        self.decoder = nn.GRUCell(input_size=2, hidden_size=64).to(self.device)
        self.output = nn.Linear(64, 2).to(self.device)
        
    def forward(self, image_list, lidar_list, target_point, velocity, bev_points, cam_points):
        '''
        Predicts waypoint from geometric feature projections of image + LiDAR input
        Args:
            image_list (list): list of input images
            lidar_list (list): list of input LiDAR BEV
            target_point (tensor): goal location registered to ego-frame
            velocity (tensor): input velocity from speedometer
            bev_points (tensor): projected image pixels onto the BEV grid
            cam_points (tensor): projected LiDAR point cloud onto the image space
        '''
        fused_features = self.encoder(image_list, lidar_list, velocity, bev_points, cam_points)
        z = self.join(fused_features)

        output_wp = list()

        # initial input variable to GRU
        x = torch.zeros(size=(z.shape[0], 2), dtype=z.dtype).to(self.device)

        # autoregressive generation of output waypoints
        for _ in range(self.pred_len):
            x_in = x + target_point
            z = self.decoder(x_in, z)
            dx = self.output(z)
            x = dx + x
            output_wp.append(x)

        pred_wp = torch.stack(output_wp, dim=1)

        return pred_wp

    def control_pid(self, waypoints, velocity):
        ''' 
        Predicts vehicle control with a PID controller.
        Args:
            waypoints (tensor): predicted waypoints
            velocity (tensor): speedometer input
        '''
        assert(waypoints.size(0)==1)
        waypoints = waypoints[0].data.cpu().numpy()

        # flip y is (forward is negative in our waypoints)
        waypoints[:,1] *= -1
        speed = velocity[0].data.cpu().numpy()

        aim = (waypoints[1] + waypoints[0]) / 2.0
        angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90
        steer = self.turn_controller.step(angle)
        steer = np.clip(steer, -1.0, 1.0)

        desired_speed = np.linalg.norm(waypoints[0] - waypoints[1]) * 2.0
        brake = desired_speed < self.config.brake_speed or (speed / desired_speed) > self.config.brake_ratio

        delta = np.clip(desired_speed - speed, 0.0, self.config.clip_delta)
        throttle = self.speed_controller.step(delta)
        throttle = np.clip(throttle, 0.0, self.config.max_throttle)
        throttle = throttle if not brake else 0.0

        metadata = {
            'speed': float(speed.astype(np.float64)),
            'steer': float(steer),
            'throttle': float(throttle),
            'brake': float(brake),
            'wp_2': tuple(waypoints[1].astype(np.float64)),
            'wp_1': tuple(waypoints[0].astype(np.float64)),
            'desired_speed': float(desired_speed.astype(np.float64)),
            'angle': float(angle.astype(np.float64)),
            'aim': tuple(aim.astype(np.float64)),
            'delta': float(delta.astype(np.float64)),
        }

        return steer, throttle, brake, metadata