Skip to content
Snippets Groups Projects
model.py 4.46 KiB
Newer Older
  • Learn to ignore specific revisions
  • Aditya Prakash's avatar
    Aditya Prakash committed
    import torch 
    from torch import nn
    import torch.nn.functional as F
    from torchvision import models
    
    
    class ImageCNN(nn.Module):
        """ Encoder network for image input list.
        Args:
            c_dim (int): output dimension of the latent embedding
            normalize (bool): whether the input images should be normalized
        """
    
        def __init__(self, c_dim, normalize=False):
            super().__init__()
            self.normalize = normalize
            self.features = models.resnet18(pretrained=True)
            self.features.fc = nn.Sequential()
    
        def forward(self, inputs):
            c = 0
            for x in inputs:
                if self.normalize:
                    x = normalize_imagenet(x)
                c += self.features(x)
            return c
    
    def normalize_imagenet(x):
        """ Normalize input images according to ImageNet standards.
        Args:
            x (tensor): input images
        """
        x = x.clone()
        x[:, 0] = (x[:, 0] - 0.485) / 0.229
        x[:, 1] = (x[:, 1] - 0.456) / 0.224
        x[:, 2] = (x[:, 2] - 0.406) / 0.225
        return x
    
    
    class Controller(nn.Module):
        """ Decoder with velocity input, velocity prediction and conditional control outputs.
        Args:
            num_branch (int): number of conditional branches
            dim (int): input dimension
            c_dim (int): dimension of latent conditioned code c
            hidden_size (int): hidden size of each decoder branch
            input_velocity (bool): whether to add input velocity information to encoding
            predict_velocity (bool): whether to output a velocity branch prediction
        """
    
        def __init__(self, num_branch=6, dim=1, c_dim=512, hidden_size=256, 
                                    input_velocity=True, predict_velocity=True):
            super().__init__()
            self.num_branch = num_branch
            self.input_velocity = input_velocity
            self.predict_velocity = predict_velocity
            
            # Project input velocity measurement to feature size
            if input_velocity:
                self.vel_in = nn.Sequential(
                    nn.Linear(dim, hidden_size),
                    nn.ReLU(True),
                    nn.Linear(hidden_size, c_dim),
                )
            
            # Project feature to velocity prediction
            if predict_velocity:
                self.vel_out = nn.Sequential(
                    nn.Linear(c_dim, hidden_size),
                    nn.ReLU(True),
                    nn.Linear(hidden_size, hidden_size),
                    nn.ReLU(True),
                    nn.Linear(hidden_size, dim),
                )
    
            # Control branches
            fc_branch_list = []
            for i in range(num_branch):
                fc_branch_list.append(nn.Sequential(
                nn.Linear(c_dim, hidden_size),
                nn.ReLU(True),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(True),
                nn.Linear(hidden_size, 3),
                nn.Sigmoid(),
            ))
    
            self.branches = nn.ModuleList(fc_branch_list)
    
        def forward(self, c, velocity, command):
            batch_size = c.size(0)
            encoding = c
    
            if self.input_velocity:
                encoding += self.vel_in(velocity.unsqueeze(1))
    
            control_pred = 0.
            for i, branch in enumerate(self.branches):
                # Choose control for branch of only active command
                # We check for (command - 1) since navigational command 0 is ignored
                control_pred += branch(encoding) * (i == (command - 1)).unsqueeze(1).expand(batch_size,3)
            
            if self.predict_velocity:
                velocity_pred = self.vel_out(c)
                return control_pred, velocity_pred
    
            return control_pred
    
    
    class CILRS(nn.Module):
    
        def __init__(self, config, device):
            super().__init__()
            self.config = config
            self.device = device
            self.encoder = ImageCNN(512, normalize=True).to(self.device)
            self.controller = Controller(num_branch=6, dim=1, c_dim=512, hidden_size=256, 
                                    input_velocity=True, predict_velocity=True).to(self.device)
    
        def forward(self, c, velocity, command):
            ''' Predicts vehicle control.
            Args:
                c (tensor): latent conditioned code c
                velocity (tensor): speedometer input
                command (tensor): high-level navigational command
            '''
            c = sum(c)
            control_pred, velocity_pred = self.controller(c, velocity, command)
            steer = control_pred[:,0] * 2 - 1. # convert from [0,1] to [-1,1]
            throttle = control_pred[:,1] * self.config.max_throttle
            brake = control_pred[:,2]
            return steer, throttle, brake, velocity_pred