From 5a2ff56ca5536c028fad8a59c0589bfb90970889 Mon Sep 17 00:00:00 2001
From: Aditya Prakash <adityaprakash229997@gmail.com>
Date: Sat, 25 Sep 2021 17:55:02 -0500
Subject: [PATCH] attention map visualizations

---
 transfuser/config.py    |   7 +
 transfuser/model_viz.py | 509 ++++++++++++++++++++++++++++++++++++++++
 transfuser/viz.py       | 163 +++++++++++++
 3 files changed, 679 insertions(+)
 create mode 100644 transfuser/model_viz.py
 create mode 100644 transfuser/viz.py

diff --git a/transfuser/config.py b/transfuser/config.py
index 5b9d33c..091ebc2 100644
--- a/transfuser/config.py
+++ b/transfuser/config.py
@@ -16,6 +16,13 @@ class GlobalConfig:
     for town in val_towns:
         val_data.append(os.path.join(root_dir, town+'_short'))
 
+    # visualizing transformer attention maps
+    viz_root = '/mnt/qb/geiger/kchitta31/data_06_21'
+    viz_towns = ['Town05_tiny']
+    viz_data = []
+    for town in viz_towns:
+        viz_data.append(os.path.join(viz_root, town))
+
     ignore_sides = True # don't consider side cameras
     ignore_rear = True # don't consider rear cameras
     n_views = 1 # no. of camera views
diff --git a/transfuser/model_viz.py b/transfuser/model_viz.py
new file mode 100644
index 0000000..f0d6b48
--- /dev/null
+++ b/transfuser/model_viz.py
@@ -0,0 +1,509 @@
+import math
+from collections import deque
+
+import numpy as np
+import torch 
+from torch import nn
+import torch.nn.functional as F
+from torchvision import models
+
+
+class ImageCNN(nn.Module):
+    """ 
+    Encoder network for image input list.
+    Args:
+        c_dim (int): output dimension of the latent embedding
+        normalize (bool): whether the input images should be normalized
+    """
+
+    def __init__(self, c_dim, normalize=True):
+        super().__init__()
+        self.normalize = normalize
+        self.features = models.resnet34(pretrained=True)
+        self.features.fc = nn.Sequential()
+
+    def forward(self, inputs):
+        c = 0
+        for x in inputs:
+            if self.normalize:
+                x = normalize_imagenet(x)
+            c += self.features(x)
+        return c
+
+def normalize_imagenet(x):
+    """ Normalize input images according to ImageNet standards.
+    Args:
+        x (tensor): input images
+    """
+    x = x.clone()
+    x[:, 0] = (x[:, 0] - 0.485) / 0.229
+    x[:, 1] = (x[:, 1] - 0.456) / 0.224
+    x[:, 2] = (x[:, 2] - 0.406) / 0.225
+    return x
+
+
+class LidarEncoder(nn.Module):
+    """
+    Encoder network for LiDAR input list
+    Args:
+        num_classes: output feature dimension
+        in_channels: input channels
+    """
+
+    def __init__(self, num_classes=512, in_channels=2):
+        super().__init__()
+
+        self._model = models.resnet18()
+        self._model.fc = nn.Sequential()
+        _tmp = self._model.conv1
+        self._model.conv1 = nn.Conv2d(in_channels, out_channels=_tmp.out_channels, 
+            kernel_size=_tmp.kernel_size, stride=_tmp.stride, padding=_tmp.padding, bias=_tmp.bias)
+
+    def forward(self, inputs):
+        features = 0
+        for lidar_data in inputs:
+            lidar_feature = self._model(lidar_data)
+            features += lidar_feature
+
+        return features
+
+
+class SelfAttention(nn.Module):
+    """
+    A vanilla multi-head self-attention layer with a projection at the end.
+    Also returns the attention weights so they can be visualized.
+    """
+
+    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop):
+        super().__init__()
+        assert n_embd % n_head == 0
+        # key, query, value projections for all heads
+        self.key = nn.Linear(n_embd, n_embd)
+        self.query = nn.Linear(n_embd, n_embd)
+        self.value = nn.Linear(n_embd, n_embd)
+        # regularization
+        self.attn_drop = nn.Dropout(attn_pdrop)
+        self.resid_drop = nn.Dropout(resid_pdrop)
+        # output projection
+        self.proj = nn.Linear(n_embd, n_embd)
+        self.n_head = n_head
+
+    def forward(self, x):
+        B, T, C = x.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+        # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = F.softmax(att, dim=-1)
+        att = self.attn_drop(att)
+        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.resid_drop(self.proj(y))
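+        # return the attention weights (B, nh, T, T) alongside the output for visualization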
+        return y, att
+
+
+class Block(nn.Module):
+    """ an unassuming Transformer block """
+
+    def __init__(self, n_embd, n_head, block_exp, attn_pdrop, resid_pdrop):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(n_embd)
+        self.ln2 = nn.LayerNorm(n_embd)
+        self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop)
+        self.mlp = nn.Sequential(
+            nn.Linear(n_embd, block_exp * n_embd),
+            nn.ReLU(True), # changed from GELU
+            nn.Linear(block_exp * n_embd, n_embd),
+            nn.Dropout(resid_pdrop),
+        )
+
+    def forward(self, in_tuple):
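+        # the (features, attention) tuple interface lets nn.Sequential thread the attention
+        # map through stacked blocks; the attention returned by the stack is that of the final block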
+        x, _ = in_tuple
+        B, T, C = x.size()
+
+        x_up, att = self.attn(self.ln1(x))
+        x = x + x_up
+        x = x + self.mlp(self.ln2(x))
+
+        return (x, att)
+
+
+class GPT(nn.Module):
+    """ GPT-style transformer for image + LiDAR token fusion, with a context size of block_size """
+
+    def __init__(self, n_embd, n_head, block_exp, n_layer, 
+                    vert_anchors, horz_anchors, seq_len, 
+                    embd_pdrop, attn_pdrop, resid_pdrop, config):
+        super().__init__()
+        self.n_embd = n_embd
+        self.seq_len = seq_len
+        self.vert_anchors = vert_anchors
+        self.horz_anchors = horz_anchors
+        self.config = config
+
+        # positional embedding parameter (learnable), image + lidar
+        self.pos_emb = nn.Parameter(torch.zeros(1, (self.config.n_views + 1) * seq_len * vert_anchors * horz_anchors, n_embd))
+        
+        # velocity embedding
+        self.vel_emb = nn.Linear(1, n_embd)
+        self.drop = nn.Dropout(embd_pdrop)
+
+        # transformer
+        self.blocks = nn.Sequential(*[Block(n_embd, n_head, 
+                        block_exp, attn_pdrop, resid_pdrop)
+                        for layer in range(n_layer)])
+        
+        # decoder head
+        self.ln_f = nn.LayerNorm(n_embd)
+
+        self.block_size = seq_len
+        self.apply(self._init_weights)
+
+    def get_block_size(self):
+        return self.block_size
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def configure_optimizers(self):
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.BatchNorm2d)
+        for mn, m in self.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+                if pn.endswith('bias'):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+
+        # special case the position embedding parameter in the root GPT module as not decayed
+        no_decay.add('pos_emb')
+
+        # create the pytorch optimizer object
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        optim_groups = [
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+        ]
+
+        return optim_groups
+
+    def forward(self, image_tensor, lidar_tensor, velocity):
+        """
+        Args:
+            image_tensor (tensor): B*n_views*seq_len, C, H, W
+            lidar_tensor (tensor): B*seq_len, C, H, W
+            velocity (tensor): ego-velocity
+        """
+        
+        bz = lidar_tensor.shape[0] // self.seq_len
+        h, w = lidar_tensor.shape[2:4]
+        
+        # forward the image model for token embeddings
+        image_tensor = image_tensor.view(bz, self.config.n_views * self.seq_len, -1, h, w)
+        lidar_tensor = lidar_tensor.view(bz, self.seq_len, -1, h, w)
+
+        # pad token embeddings along number of tokens dimension
+        token_embeddings = torch.cat([image_tensor, lidar_tensor], dim=1).permute(0,1,3,4,2).contiguous()
+        token_embeddings = token_embeddings.view(bz, -1, self.n_embd) # (B, an * T, C)
+
+        # project velocity to n_embed
+        velocity_embeddings = self.vel_emb(velocity.unsqueeze(1)) # (B, C)
+
+        # add (learnable) positional embedding and velocity embedding for all tokens
+        x = self.drop(self.pos_emb + token_embeddings + velocity_embeddings.unsqueeze(1)) # (B, an * T, C)
+        # x = self.drop(token_embeddings + velocity_embeddings.unsqueeze(1)) # (B, an * T, C)
+        x, attn_map = self.blocks((x, None)) # (B, an * T, C)
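+        # attn_map is the last block's attention: (B, n_head, n_tokens, n_tokens), n_tokens = (n_views + 1) * seq_len * vert_anchors * horz_anchors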
+        x = self.ln_f(x) # (B, an * T, C)
+        x = x.view(bz, (self.config.n_views + 1) * self.seq_len, self.vert_anchors, self.horz_anchors, self.n_embd)
+        x = x.permute(0,1,4,2,3).contiguous() # same as token_embeddings
+
+        image_tensor_out = x[:, :self.config.n_views*self.seq_len, :, :, :].contiguous().view(bz * self.config.n_views * self.seq_len, -1, h, w)
+        lidar_tensor_out = x[:, self.config.n_views*self.seq_len:, :, :, :].contiguous().view(bz * self.seq_len, -1, h, w)
+        
+        return image_tensor_out, lidar_tensor_out, attn_map
+
+
+class Encoder(nn.Module):
+    """
+    Multi-view Multi-scale Fusion Transformer for image + LiDAR feature fusion
+    """
+
+    def __init__(self, config, **kwargs):
+        super(Encoder, self).__init__()
+        self.config = config
+
+        self.avgpool = nn.AdaptiveAvgPool2d((self.config.vert_anchors, self.config.horz_anchors))
+        
+        self.image_encoder = ImageCNN(512, normalize=True)
+        self.lidar_encoder = LidarEncoder(num_classes=512, in_channels=2)
+
+        self.transformer1 = GPT(n_embd=64,
+                            n_head=config.n_head, 
+                            block_exp=config.block_exp, 
+                            n_layer=config.n_layer, 
+                            vert_anchors=config.vert_anchors, 
+                            horz_anchors=config.horz_anchors, 
+                            seq_len=config.seq_len, 
+                            embd_pdrop=config.embd_pdrop, 
+                            attn_pdrop=config.attn_pdrop, 
+                            resid_pdrop=config.resid_pdrop,
+                            config=config)
+        self.transformer2 = GPT(n_embd=128,
+                            n_head=config.n_head, 
+                            block_exp=config.block_exp, 
+                            n_layer=config.n_layer, 
+                            vert_anchors=config.vert_anchors, 
+                            horz_anchors=config.horz_anchors, 
+                            seq_len=config.seq_len, 
+                            embd_pdrop=config.embd_pdrop, 
+                            attn_pdrop=config.attn_pdrop, 
+                            resid_pdrop=config.resid_pdrop,
+                            config=config)
+        self.transformer3 = GPT(n_embd=256,
+                            n_head=config.n_head, 
+                            block_exp=config.block_exp, 
+                            n_layer=config.n_layer, 
+                            vert_anchors=config.vert_anchors, 
+                            horz_anchors=config.horz_anchors, 
+                            seq_len=config.seq_len, 
+                            embd_pdrop=config.embd_pdrop, 
+                            attn_pdrop=config.attn_pdrop, 
+                            resid_pdrop=config.resid_pdrop,
+                            config=config)
+        self.transformer4 = GPT(n_embd=512,
+                            n_head=config.n_head, 
+                            block_exp=config.block_exp, 
+                            n_layer=config.n_layer, 
+                            vert_anchors=config.vert_anchors, 
+                            horz_anchors=config.horz_anchors, 
+                            seq_len=config.seq_len, 
+                            embd_pdrop=config.embd_pdrop, 
+                            attn_pdrop=config.attn_pdrop, 
+                            resid_pdrop=config.resid_pdrop,
+                            config=config)
+
+        
+    def forward(self, image_list, lidar_list, velocity):
+        if self.image_encoder.normalize:
+            image_list = [normalize_imagenet(image_input) for image_input in image_list]
+
+        bz, _, h, w = lidar_list[0].shape
+        img_channel = image_list[0].shape[1]
+        lidar_channel = lidar_list[0].shape[1]
+        self.config.n_views = len(image_list) // self.config.seq_len
+
+        image_tensor = torch.stack(image_list, dim=1).view(bz * self.config.n_views * self.config.seq_len, img_channel, h, w)
+        lidar_tensor = torch.stack(lidar_list, dim=1).view(bz * self.config.seq_len, lidar_channel, h, w)
+
+        image_features = self.image_encoder.features.conv1(image_tensor)
+        image_features = self.image_encoder.features.bn1(image_features)
+        image_features = self.image_encoder.features.relu(image_features)
+        image_features = self.image_encoder.features.maxpool(image_features)
+        lidar_features = self.lidar_encoder._model.conv1(lidar_tensor)
+        lidar_features = self.lidar_encoder._model.bn1(lidar_features)
+        lidar_features = self.lidar_encoder._model.relu(lidar_features)
+        lidar_features = self.lidar_encoder._model.maxpool(lidar_features)
+
+        image_features = self.image_encoder.features.layer1(image_features)
+        lidar_features = self.lidar_encoder._model.layer1(lidar_features)
+        # fusion at (B, 64, 64, 64)
+        image_embd_layer1 = self.avgpool(image_features)
+        lidar_embd_layer1 = self.avgpool(lidar_features)
+        image_features_layer1, lidar_features_layer1, attn_map1 = self.transformer1(image_embd_layer1, lidar_embd_layer1, velocity)
+        image_features_layer1 = F.interpolate(image_features_layer1, scale_factor=8, mode='bilinear')
+        lidar_features_layer1 = F.interpolate(lidar_features_layer1, scale_factor=8, mode='bilinear')
+        image_features = image_features + image_features_layer1
+        lidar_features = lidar_features + lidar_features_layer1
+
+        image_features = self.image_encoder.features.layer2(image_features)
+        lidar_features = self.lidar_encoder._model.layer2(lidar_features)
+        # fusion at (B, 128, 32, 32)
+        image_embd_layer2 = self.avgpool(image_features)
+        lidar_embd_layer2 = self.avgpool(lidar_features)
+        image_features_layer2, lidar_features_layer2, attn_map2 = self.transformer2(image_embd_layer2, lidar_embd_layer2, velocity)
+        image_features_layer2 = F.interpolate(image_features_layer2, scale_factor=4, mode='bilinear')
+        lidar_features_layer2 = F.interpolate(lidar_features_layer2, scale_factor=4, mode='bilinear')
+        image_features = image_features + image_features_layer2
+        lidar_features = lidar_features + lidar_features_layer2
+
+        image_features = self.image_encoder.features.layer3(image_features)
+        lidar_features = self.lidar_encoder._model.layer3(lidar_features)
+        # fusion at (B, 256, 16, 16)
+        image_embd_layer3 = self.avgpool(image_features)
+        lidar_embd_layer3 = self.avgpool(lidar_features)
+        image_features_layer3, lidar_features_layer3, attn_map3 = self.transformer3(image_embd_layer3, lidar_embd_layer3, velocity)
+        image_features_layer3 = F.interpolate(image_features_layer3, scale_factor=2, mode='bilinear')
+        lidar_features_layer3 = F.interpolate(lidar_features_layer3, scale_factor=2, mode='bilinear')
+        image_features = image_features + image_features_layer3
+        lidar_features = lidar_features + lidar_features_layer3
+
+        image_features = self.image_encoder.features.layer4(image_features)
+        lidar_features = self.lidar_encoder._model.layer4(lidar_features)
+        # fusion at (B, 512, 8, 8)
+        image_embd_layer4 = self.avgpool(image_features)
+        lidar_embd_layer4 = self.avgpool(lidar_features)
+        image_features_layer4, lidar_features_layer4, attn_map4 = self.transformer4(image_embd_layer4, lidar_embd_layer4, velocity)
+        image_features = image_features + image_features_layer4
+        lidar_features = lidar_features + lidar_features_layer4
+
+        image_features = self.image_encoder.features.avgpool(image_features)
+        image_features = torch.flatten(image_features, 1)
+        image_features = image_features.view(bz, self.config.n_views * self.config.seq_len, -1)
+        lidar_features = self.lidar_encoder._model.avgpool(lidar_features)
+        lidar_features = torch.flatten(lidar_features, 1)
+        lidar_features = lidar_features.view(bz, self.config.seq_len, -1)
+
+        fused_features = torch.cat([image_features, lidar_features], dim=1)
+        fused_features = torch.sum(fused_features, dim=1)
+
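+        # stack the attention maps of the four fusion stages: (B, 4, n_head, n_tokens, n_tokens)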
+        attn_map = torch.stack([attn_map1, attn_map2, attn_map3, attn_map4], dim=1)
+
+        return fused_features, attn_map
+
+
+class PIDController(object):
+    def __init__(self, K_P=1.0, K_I=0.0, K_D=0.0, n=20):
+        self._K_P = K_P
+        self._K_I = K_I
+        self._K_D = K_D
+
+        self._window = deque([0 for _ in range(n)], maxlen=n)
+        self._max = 0.0
+        self._min = 0.0
+
+    def step(self, error):
+        self._window.append(error)
+        self._max = max(self._max, abs(error))
+        self._min = -abs(self._max)
+
+        if len(self._window) >= 2:
+            integral = np.mean(self._window)
+            derivative = (self._window[-1] - self._window[-2])
+        else:
+            integral = 0.0
+            derivative = 0.0
+
+        return self._K_P * error + self._K_I * integral + self._K_D * derivative
+
+
+class TransFuser(nn.Module):
+    '''
+    Transformer-based feature fusion followed by GRU-based waypoint prediction network and PID controller
+    '''
+
+    def __init__(self, config, device):
+        super().__init__()
+        self.device = device
+        self.config = config
+        self.pred_len = config.pred_len
+
+        self.turn_controller = PIDController(K_P=config.turn_KP, K_I=config.turn_KI, K_D=config.turn_KD, n=config.turn_n)
+        self.speed_controller = PIDController(K_P=config.speed_KP, K_I=config.speed_KI, K_D=config.speed_KD, n=config.speed_n)
+
+        self.encoder = Encoder(config).to(self.device)
+
+        self.join = nn.Sequential(
+                            nn.Linear(512, 256),
+                            nn.ReLU(inplace=True),
+                            nn.Linear(256, 128),
+                            nn.ReLU(inplace=True),
+                            nn.Linear(128, 64),
+                            nn.ReLU(inplace=True),
+                        ).to(self.device)
+        self.decoder = nn.GRUCell(input_size=2, hidden_size=64).to(self.device)
+        self.output = nn.Linear(64, 2).to(self.device)
+        
+    def forward(self, image_list, lidar_list, target_point, velocity):
+        '''
+        Predicts waypoint from geometric feature projections of image + LiDAR input
+        Args:
+            image_list (list): list of input images
+            lidar_list (list): list of input LiDAR BEV
+            target_point (tensor): goal location registered to ego-frame
+            velocity (tensor): input velocity from speedometer
+        '''
+        fused_features, attn_map = self.encoder(image_list, lidar_list, velocity)
+        z = self.join(fused_features)
+
+        output_wp = list()
+
+        # initial input variable to GRU
+        x = torch.zeros(size=(z.shape[0], 2), dtype=z.dtype).to(self.device)
+
+        # autoregressive generation of output waypoints
+        for _ in range(self.pred_len):
+            # x_in = torch.cat([x, target_point], dim=1)
+            x_in = x + target_point
+            z = self.decoder(x_in, z)
+            dx = self.output(z)
+            x = dx + x
+            output_wp.append(x)
+
+        pred_wp = torch.stack(output_wp, dim=1)
+
+        return pred_wp, attn_map
+
+    def control_pid(self, waypoints, velocity):
+        ''' 
+        Predicts vehicle control with a PID controller.
+        Args:
+            waypoints (tensor): predicted waypoints
+            velocity (tensor): speedometer input
+        '''
+        assert(waypoints.size(0)==1)
+        waypoints = waypoints[0].data.cpu().numpy()
+
+        # flip y axis (forward is negative in our waypoint frame)
+        waypoints[:,1] *= -1
+        speed = velocity[0].data.cpu().numpy()
+
+        desired_speed = np.linalg.norm(waypoints[0] - waypoints[1]) * 2.0
+        brake = desired_speed < self.config.brake_speed or (speed / desired_speed) > self.config.brake_ratio
+
+        aim = (waypoints[1] + waypoints[0]) / 2.0
+        angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90
+        if(speed < 0.01):
+            angle = np.array(0.0) # When we don't move we don't want the angle error to accumulate in the integral
+        steer = self.turn_controller.step(angle)
+        steer = np.clip(steer, -1.0, 1.0)
+
+        delta = np.clip(desired_speed - speed, 0.0, self.config.clip_delta)
+        throttle = self.speed_controller.step(delta)
+        throttle = np.clip(throttle, 0.0, self.config.max_throttle)
+        throttle = throttle if not brake else 0.0
+
+        metadata = {
+            'speed': float(speed.astype(np.float64)),
+            'steer': float(steer),
+            'throttle': float(throttle),
+            'brake': float(brake),
+            'wp_2': tuple(waypoints[1].astype(np.float64)),
+            'wp_1': tuple(waypoints[0].astype(np.float64)),
+            'desired_speed': float(desired_speed.astype(np.float64)),
+            'angle': float(angle.astype(np.float64)),
+            'aim': tuple(aim.astype(np.float64)),
+            'delta': float(delta.astype(np.float64)),
+        }
+
+        return steer, throttle, brake, metadata
\ No newline at end of file
diff --git a/transfuser/viz.py b/transfuser/viz.py
new file mode 100644
index 0000000..14c9e8d
--- /dev/null
+++ b/transfuser/viz.py
@@ -0,0 +1,163 @@
+import argparse
+import os
+from tqdm import tqdm
+
+import numpy as np
+from PIL import Image
+import torch
+from torch.utils.data import DataLoader
+torch.backends.cudnn.benchmark = True
+
+from config import GlobalConfig
+from model_viz import TransFuser
+from data import CARLA_Data
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_path', type=str, required=True, help='path to model ckpt')
+parser.add_argument('--device', type=str, default='cuda', help='Device to use')
+parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
+parser.add_argument('--save_path', type=str, default=None, help='path to save visualizations')
+parser.add_argument('--total_size', type=int, default=1000, help='total images for which to generate visualizations')
+parser.add_argument('--attn_thres', type=int, default=1, help='minimum # tokens of other modality required for global context')
+
+args = parser.parse_args()
+
+# Config
+config = GlobalConfig()
+
+if args.save_path is not None and not os.path.isdir(args.save_path):
+	os.makedirs(args.save_path, exist_ok=True)
+
+# Data
+viz_data = CARLA_Data(root=config.viz_data, config=config)
+dataloader_viz = DataLoader(viz_data, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True)
+
+# Model
+model = TransFuser(config, args.device)
+
+model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+params = sum([np.prod(p.size()) for p in model_parameters])
+print ('Total parameters: ', params)
+
+model.load_state_dict(torch.load(os.path.join(args.model_path, 'best_model.pth')))
+model.eval()
+
+x = [i for i in range(16, 512, 32)]
+y = [i for i in range(16, 256, 32)]
+patch_centers = []
+for i in x:
+	for j in y:
+		patch_centers.append((i,j))
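+# patch_centers maps each of the 128 tokens to the center of its 32x32 patch in the 512x256
+# combined image (camera crop stacked above the LiDAR BEV): tokens 0-63 are camera, 64-127 are LiDAR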
+
+cnt = 0
+
+# central tokens in both modalities, adjusted for alignment mismatch
+central_image_tokens = list(range(16,40))
+central_lidar_tokens = list(range(4,64,8))+list(range(6,64,8))+list(range(5,64,8))
+global_context = [[], [], [], []]
+
+with torch.no_grad():
+	for enum, data in enumerate(tqdm(dataloader_viz)):
+		
+		if enum*args.batch_size >= args.total_size: # total images for which to generate visualizations
+			break
+		
+		# create batch and move to GPU
+		fronts_in = data['fronts']
+		lidars_in = data['lidars']
+		fronts = []
+		bevs = []
+		lidars = []
+		for i in range(config.seq_len):
+			fronts.append(fronts_in[i].to(args.device, dtype=torch.float32))
+			lidars.append(lidars_in[i].to(args.device, dtype=torch.float32))
+
+		# driving labels
+		command = data['command'].to(args.device)
+		gt_velocity = data['velocity'].to(args.device, dtype=torch.float32)
+
+		# target point
+		target_point = torch.stack(data['target_point'], dim=1).to(args.device, dtype=torch.float32)
+		
+		pred_wp, attn_map = model(fronts, lidars, target_point, gt_velocity)
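+		# attn_map shape is (B, 4, n_head, 128, 128) for the default config (1 view, seq_len 1, 8x8 anchors)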
+
+		# attention maps from the four fusion stages (each of shape B x n_head x n_tokens x n_tokens)
+		attn_map1 = attn_map[:,0,:,:,:].detach().cpu().numpy()
+		attn_map2 = attn_map[:,1,:,:,:].detach().cpu().numpy()
+		attn_map3 = attn_map[:,2,:,:,:].detach().cpu().numpy()
+		attn_map4 = attn_map[:,3,:,:,:].detach().cpu().numpy()
+
+		curr_cnt = 0
+		for idx in range(data['fronts'][0].shape[0]): # guard against a final partial batch
+			img = np.transpose(data['fronts'][0][idx].numpy(), (1,2,0))
+			lidar_bev = (data['lidar_bevs'][0][idx].squeeze(0).numpy()*255).astype(np.uint8)
+			lidar_bev = np.stack([lidar_bev]*3, 2)
+			combined_img = np.vstack([img, lidar_bev])
+
+			if args.save_path is not None:
+				img_path = os.path.join(args.save_path, str(cnt).zfill(5))
+				if not os.path.isdir(img_path):
+					os.makedirs(img_path, exist_ok=True)
+				Image.fromarray(img).save(os.path.join(img_path, 'input_image.png'))
+				Image.fromarray(np.rot90(lidar_bev, 1, (1,0))).save(os.path.join(img_path, 'input_lidar.png')) # adjust for alignment mismatch
+				
+			cnt += 1
+
+			for head in range(4):
+				curr_attn = attn_map4[idx,head]
+				for token in range(128):
+					attn_vector = curr_attn[token]
+					attn_indices = np.argpartition(attn_vector, -5)[-5:]
+
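+					# a central image token counts as using global context when at least attn_thres
+					# of its top-5 attended tokens belong to the LiDAR modality (indices 64-127)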
+					if token in central_image_tokens:
+						if np.sum(attn_indices>=64) >= args.attn_thres:
+							global_context[head].append(1)
+						else:
+							global_context[head].append(0)
+
+					# if token in central_lidar_tokens:
+					# 	if np.sum(attn_indices<64) >= args.attn_thres:
+					# 		global_context[head].append(1)
+					# 	else:
+					# 		global_context[head].append(0)
+
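+					# only visualize source tokens whose top-5 attention crosses modalities (camera <-> LiDAR)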
+					if (token<64 and (attn_indices>=64).any()) or (token>=64 and (attn_indices<64).any()):
+						
+						if args.save_path is not None:
+							curr_path = os.path.join(img_path, str(token)+'_'+str(head)+'_'+'_'.join(str(xx) for xx in attn_indices))
+							if not os.path.isdir(curr_path):
+								os.makedirs(curr_path, exist_ok=True)
+
+						tmp_attn = np.zeros((512, 256, 3)).astype(np.uint8)
+						row = patch_centers[token][0]
+						col = patch_centers[token][1]
+						tmp_attn[row-16:row+16, col-16:col+16, :]=1
+						cropped_img = combined_img*tmp_attn
+						if args.save_path is not None:
+							if token<64:
+								Image.fromarray(cropped_img[:256,:,:]).save(os.path.join(curr_path, 'source_token_img.png'))
+							else:
+								Image.fromarray(np.rot90(cropped_img[256:,:,:], 1, (1,0))).save(os.path.join(curr_path, 'source_token_lidar.png'))
+
+						tmp_attn = np.zeros((512, 256, 3)).astype(np.uint8)
+						for attn_token in attn_indices:
+							row = patch_centers[attn_token][0]
+							col = patch_centers[attn_token][1]
+							tmp_attn[row-16:row+16, col-16:col+16, :]=1
+						cropped_img = combined_img*tmp_attn
+						if args.save_path is not None:
+							Image.fromarray(cropped_img[:256,:,:]).save(os.path.join(curr_path, 'attended_token_img.png'))
+							Image.fromarray(np.rot90(cropped_img[256:,:,:], 1, (1,0))).save(os.path.join(curr_path, 'attended_token_lidar.png'))
+
+						curr_cnt += 1
+
+
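+# a token counts as using global context if at least one of the four heads flagged it above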
+global_context = np.array(global_context)
+global_context = np.sum(global_context, 0)
+global_context = global_context>0
+
+valid_tokens = global_context.sum()
+valid_percent = valid_tokens/len(global_context)
+
+print (global_context.sum(), len(global_context), valid_percent)
\ No newline at end of file
-- 
GitLab