# NOTE(review): getgrouplist is unused in this module and os.getgrouplist only
# exists on Unix; confirm nothing re-imports it from here before removing.
from os import getgrouplist
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from typing import Optional, List
from torch import Tensor
from torch.nn.init import xavier_uniform_, zeros_, kaiming_normal_


class PointNetfeat(nn.Module):
    """Per-point shared MLP (kernel-size-1 Conv1d + BatchNorm) with max pooling.

    Args:
        input_dim: number of input channels per point.
        x: width multiplier for the hidden layers (64x, 128x, 256x).
        outchannel: 256 selects a 256-channel output; any other value
            selects ``512 * x`` output channels.
    """

    def __init__(self, input_dim, x=1, outchannel=512):
        super(PointNetfeat, self).__init__()
        if outchannel == 256:
            self.output_channel = 256
        else:
            self.output_channel = 512 * x
        self.conv1 = torch.nn.Conv1d(input_dim, 64 * x, 1)
        self.conv2 = torch.nn.Conv1d(64 * x, 128 * x, 1)
        self.conv3 = torch.nn.Conv1d(128 * x, 256 * x, 1)
        self.conv4 = torch.nn.Conv1d(256 * x, self.output_channel, 1)
        self.bn1 = nn.BatchNorm1d(64 * x)
        self.bn2 = nn.BatchNorm1d(128 * x)
        self.bn3 = nn.BatchNorm1d(256 * x)
        self.bn4 = nn.BatchNorm1d(self.output_channel)

    def forward(self, x):
        """Encode points.

        Args:
            x: (B, input_dim, N) tensor (Conv1d channel-first layout).

        Returns:
            ``(pooled, per_point)``: the max over dim 2 reshaped to
            (B, output_channel), and the pre-pooling (B, output_channel, N)
            features (BatchNorm applied, no ReLU on the last layer).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x_ori = self.bn4(self.conv4(x))
        x = torch.max(x_ori, 2, keepdim=True)[0]
        x = x.view(-1, self.output_channel)
        return x, x_ori


class PointNet(nn.Module):
    """PointNet encoder followed by three small regression heads.

    Args:
        input_dim: channels of ``x`` on the non-joint path (``pre_bn`` is
            ``BatchNorm1d(input_dim)``).
        joint_feat: when True, ``forward`` ignores ``x`` and uses ``feat``.
        model_cfg: config object; ``model_cfg.TRANS_INPUT`` sets the width of
            the returned feature vector.
    """

    def __init__(self, input_dim, joint_feat=False, model_cfg=None):
        super(PointNet, self).__init__()
        self.joint_feat = joint_feat
        channels = model_cfg.TRANS_INPUT
        times = 1
        self.feat = PointNetfeat(input_dim, 1)
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, channels)

        self.pre_bn = nn.BatchNorm1d(input_dim)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

        self.fc_s1 = nn.Linear(channels * times, 256)
        self.fc_s2 = nn.Linear(256, 3, bias=False)
        self.fc_ce1 = nn.Linear(channels * times, 256)
        self.fc_ce2 = nn.Linear(256, 3, bias=False)
        self.fc_hr1 = nn.Linear(channels * times, 256)
        self.fc_hr2 = nn.Linear(256, 1, bias=False)

    def forward(self, x, feat=None):
        """Run the encoder (or reuse ``feat``) and the three heads.

        Args:
            x: (B, input_dim, N) points, used only when ``joint_feat`` is False.
            feat: used only when ``joint_feat`` is True. A tensor with more than
                2 dims is max-pooled over dim 2 and passed through fc1/fc2. A
                tensor with at most 2 dims is used as the feature vector as-is.

        Returns:
            ``(preds, feat, feat_traj)``: ``preds`` concatenates centers (3),
            sizes (3) and headings (1) on the last dim. ``feat`` is the feature
            vector fed to the heads. ``feat_traj`` holds the per-point encoder
            features, or None on the joint path.
        """
        if self.joint_feat:
            if len(feat.shape) > 2:
                feat = torch.max(feat, 2, keepdim=True)[0]
                # Bug fix: ``output_channel`` lives on the PointNetfeat
                # sub-module; ``self.output_channel`` raised AttributeError.
                x = feat.view(-1, self.feat.output_channel)
                x = F.relu(self.bn1(self.fc1(x)))
                feat = F.relu(self.bn2(self.fc2(x)))
            # A feature tensor with <= 2 dims is used unchanged.
            feat_traj = None
        else:
            x, feat_traj = self.feat(self.pre_bn(x))
            x = F.relu(self.bn1(self.fc1(x)))
            feat = F.relu(self.bn2(self.fc2(x)))

        centers = self.fc_ce2(F.relu(self.fc_ce1(feat)))
        sizes = self.fc_s2(F.relu(self.fc_s1(feat)))
        headings = self.fc_hr2(F.relu(self.fc_hr1(feat)))
        return torch.cat([centers, sizes, headings], -1), feat, feat_traj

    def init_weights(self):
        """Kaiming-normal init for Conv1d/Linear weights; zero their biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    zeros_(m.bias)


class MLP(nn.Module):
    """Stack of ``num_layers`` Linear layers with ReLU between them (not after the last)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class SpatialMixerBlock(nn.Module):
    """MLP-mixer over a cubic grid of tokens, mixing each grid axis in turn.

    Args:
        hidden_dim: hidden width of the per-axis mixing MLPs.
        grid_size: side length ``g`` of the grid; inputs must hold ``g**3`` tokens.
        channels: feature width ``C`` (LayerNorm / FFN width).
        config: stored as-is, not read by this block.
        dropout: dropout inside the channel FFN.
    """

    def __init__(self, hidden_dim, grid_size, channels, config=None, dropout=0.0):
        super().__init__()
        self.mixer_x = MLP(input_dim=grid_size, hidden_dim=hidden_dim, output_dim=grid_size, num_layers=3)
        self.mixer_y = MLP(input_dim=grid_size, hidden_dim=hidden_dim, output_dim=grid_size, num_layers=3)
        self.mixer_z = MLP(input_dim=grid_size, hidden_dim=hidden_dim, output_dim=grid_size, num_layers=3)
        self.norm_x = nn.LayerNorm(channels)
        self.norm_y = nn.LayerNorm(channels)
        self.norm_z = nn.LayerNorm(channels)
        self.norm_channel = nn.LayerNorm(channels)
        self.ffn = nn.Sequential(
            nn.Linear(channels, 2 * channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2 * channels, channels),
        )
        self.config = config
        self.grid_size = grid_size

    def forward(self, src):
        """Mix ``src`` of shape (g**3, B, C); returns a tensor of the same shape."""
        src_3d = src.permute(1, 2, 0).contiguous().view(
            src.shape[1], src.shape[2], self.grid_size, self.grid_size, self.grid_size)
        src_3d = src_3d.permute(0, 1, 4, 3, 2).contiguous()

        # Each step: residual mix along one grid axis, then LayerNorm over channels.
        mixed_x = self.mixer_x(src_3d)
        mixed_x = src_3d + mixed_x
        mixed_x = self.norm_x(mixed_x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3).contiguous()

        mixed_y = self.mixer_y(mixed_x.permute(0, 1, 2, 4, 3)).permute(0, 1, 2, 4, 3).contiguous()
        mixed_y = mixed_x + mixed_y
        mixed_y = self.norm_y(mixed_y.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3).contiguous()

        mixed_z = self.mixer_z(mixed_y.permute(0, 1, 4, 3, 2)).permute(0, 1, 4, 3, 2).contiguous()
        mixed_z = mixed_y + mixed_z
        mixed_z = self.norm_z(mixed_z.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3).contiguous()

        # NOTE(review): mixed_z is still in the axis-reversed layout from the
        # permute(0,1,4,3,2) above. Flattening it orders tokens with the first
        # and last grid axes swapped relative to the input flattening. Confirm
        # this is intended wherever outputs are matched to input order (e.g.
        # positional embeddings).
        src_mixer = mixed_z.view(src.shape[1], src.shape[2], -1).permute(2, 0, 1)
        src_mixer = src_mixer + self.ffn(src_mixer)
        src_mixer = self.norm_channel(src_mixer)
        return src_mixer


class Transformer(nn.Module):
    """Group-token transformer encoder over proxy-point features.

    Each of ``num_groups`` learnable tokens is prepended to its group's
    ``num_proxy_points`` features. The groups are stacked along the batch dim
    and encoded. When ``num_frames == 16``, frames are first fused across the
    sequence through ``fusion_all_group`` / ``fusion_norm``.
    """

    def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False,
                 num_lidar_points=None, num_proxy_points=None, share_head=True, num_groups=None,
                 sequence_stride=None, num_frames=None):
        super().__init__()

        self.config = config
        self.share_head = share_head
        self.num_frames = num_frames
        self.nhead = nhead
        self.sequence_stride = sequence_stride
        self.num_groups = num_groups
        self.num_proxy_points = num_proxy_points
        self.num_lidar_points = num_lidar_points
        self.d_model = d_model

        encoder_layer = [TransformerEncoderLayer(self.config, d_model, nhead, dim_feedforward, dropout, activation,
                                                 normalize_before, num_lidar_points, num_groups=num_groups)
                         for i in range(num_encoder_layers)]
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm, self.config)

        self.token = nn.Parameter(torch.zeros(self.num_groups, 1, d_model))

        if self.num_frames > 4:
            self.group_length = self.num_frames // self.num_groups
            self.fusion_all_group = MLP(input_dim=self.config.hidden_dim * self.group_length,
                                        hidden_dim=self.config.hidden_dim,
                                        output_dim=self.config.hidden_dim, num_layers=4)
            self.fusion_norm = FFN(d_model, dim_feedforward)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for every parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, pos=None):
        """Encode ``src`` with one learnable token per group.

        Args:
            src: (BS, N, C) features. Group ``i`` is read from
                ``src[:, i*num_proxy_points:(i+1)*num_proxy_points]``.
            pos: optional positional embedding, permuted with ``(1, 0, 2)``
                into the sequence-first layout the encoder uses.

        Returns:
            ``(memory, tokens)``: each group's output token stacked along dim 0
            (num_groups, BS, C), and the encoder's per-layer token list.

        Raises:
            NotImplementedError: ``num_frames == 16`` with a ``sequence_stride``
                other than 1 or 4.
        """
        BS, _, _ = src.shape

        if pos is not None:
            pos = pos.permute(1, 0, 2)

        token_list = [self.token[i:(i + 1)].repeat(BS, 1, 1) for i in range(self.num_groups)]

        if self.num_frames == 16:
            if self.sequence_stride == 1:
                src_groups = src.view(src.shape[0], src.shape[1] // self.num_groups, -1).chunk(4, dim=1)
            elif self.sequence_stride == 4:
                src_groups = []
                for i in range(self.num_groups):
                    groups = []
                    for j in range(self.group_length):
                        points_index_start = (i + j * self.sequence_stride) * self.num_proxy_points
                        points_index_end = points_index_start + self.num_proxy_points
                        groups.append(src[:, points_index_start:points_index_end])
                    groups = torch.cat(groups, -1)
                    src_groups.append(groups)
            else:
                raise NotImplementedError

            src_merge = torch.cat(src_groups, 1)
            src = self.fusion_norm(src[:, :self.num_groups * self.num_proxy_points],
                                   self.fusion_all_group(src_merge))

        # Prepend each group's token to that group's proxy points, then stack groups on the batch dim.
        src = [torch.cat([token_list[i], src[:, i * self.num_proxy_points:(i + 1) * self.num_proxy_points]], dim=1)
               for i in range(self.num_groups)]
        src = torch.cat(src, dim=0)

        src = src.permute(1, 0, 2)
        memory, tokens = self.encoder(src, pos=pos)

        # Dim 1 of memory is num_groups * BS; split it back into per-group token rows.
        memory = torch.cat(memory[0:1].chunk(self.num_groups, dim=1), 0)
        return memory, tokens


class TransformerEncoder(nn.Module):
    """Apply each layer in sequence, collecting the token output of every layer.

    ``norm`` (if given) is applied only to the final ``output``, not to the
    collected tokens.
    """

    def __init__(self, encoder_layer, num_layers, norm=None, config=None):
        super().__init__()
        self.layers = nn.ModuleList(encoder_layer)
        self.num_layers = num_layers
        self.norm = norm
        self.config = config

    def forward(self, src, pos: Optional[Tensor] = None):
        token_list = []
        output = src
        for layer in self.layers:
            output, tokens = layer(output, pos=pos)
            token_list.append(tokens)
        if self.norm is not None:
            output = self.norm(output)
        return output, token_list
class TransformerEncoderLayer(nn.Module): count = 0 def __init__(self, config, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False,num_points=None,num_groups=None): super().__init__() TransformerEncoderLayer.count += 1 self.layer_count = TransformerEncoderLayer.count self.config = config self.num_point = num_points self.num_groups= num_groups self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) if self.layer_count <= self.config.enc_layers-1: self.cross_attn_layers = nn.ModuleList() for _ in range(self.num_groups): self.cross_attn_layers.append(nn.MultiheadAttention(d_model, nhead, dropout=dropout)) self.ffn = FFN(d_model, dim_feedforward) self.fusion_all_groups = MLP(input_dim = d_model*4, hidden_dim = d_model, output_dim = d_model, num_layers = 4) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self.mlp_mixer_3d = SpatialMixerBlock(self.config.use_mlp_mixer.hidden_dim,self.config.use_mlp_mixer.get('grid_size', 4),self.config.hidden_dim, self.config.use_mlp_mixer) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, src, pos: Optional[Tensor] = None): src_intra_group_fusion = self.mlp_mixer_3d(src[1:]) src = torch.cat([src[:1],src_intra_group_fusion],0) token = src[:1] if not pos is None: key = self.with_pos_embed(src_intra_group_fusion, pos[1:]) else: key = src_intra_group_fusion src_summary = self.self_attn(token, key, value=src_intra_group_fusion)[0] token = token + self.dropout1(src_summary) token = self.norm1(token) src_summary = self.linear2(self.dropout(self.activation(self.linear1(token)))) token = token + 
self.dropout2(src_summary) token = self.norm2(token) src = torch.cat([token,src[1:]],0) if self.layer_count <= self.config.enc_layers-1: src_all_groups = src[1:].view((src.shape[0]-1)*4,-1,src.shape[-1]) src_groups_list = src_all_groups.chunk(self.num_groups,0) src_all_groups = torch.cat(src_groups_list,-1) src_all_groups_fusion = self.fusion_all_groups(src_all_groups) key = self.with_pos_embed(src_all_groups_fusion, pos[1:]) query_list = [self.with_pos_embed(query, pos[1:]) for query in src_groups_list] inter_group_fusion_list = [] for i in range(self.num_groups): inter_group_fusion = self.cross_attn_layers[i](query_list[i], key, value=src_all_groups_fusion)[0] inter_group_fusion = self.ffn(src_groups_list[i],inter_group_fusion) inter_group_fusion_list.append(inter_group_fusion) src_inter_group_fusion = torch.cat(inter_group_fusion_list,1) src = torch.cat([src[:1],src_inter_group_fusion],0) return src, torch.cat(src[:1].chunk(4,1),0) def forward_pre(self, src, pos: Optional[Tensor] = None): src2 = self.norm1(src) q = k = self.with_pos_embed(src2, pos) src2 = self.self_attn(q, k, value=src2)[0] src = src + self.dropout1(src2) src2 = self.norm2(src) src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) src = src + self.dropout2(src2) return src def forward(self, src, pos: Optional[Tensor] = None): if self.normalize_before: return self.forward_pre(src, pos) return self.forward_post(src, pos) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu raise RuntimeError(F"activation should be relu/gelu, not {activation}.") class FFN(nn.Module): def __init__(self, d_model, dim_feedforward=2048, dropout=0.1,dout=None, activation="relu", normalize_before=False): super().__init__() # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 
= nn.Linear(dim_feedforward, d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before def forward(self, tgt,tgt_input): tgt = tgt + self.dropout2(tgt_input) tgt = self.norm2(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout3(tgt2) tgt = self.norm3(tgt) return tgt def build_transformer(args): return Transformer( config = args, d_model=args.hidden_dim, dropout=args.dropout, nhead=args.nheads, dim_feedforward=args.dim_feedforward, num_encoder_layers=args.enc_layers, normalize_before=args.pre_norm, num_lidar_points = args.num_lidar_points, num_proxy_points = args.num_proxy_points, num_frames = args.num_frames, sequence_stride = args.get('sequence_stride',1), num_groups=args.num_groups, )