Add File
This commit is contained in:
73
pcdet/models/backbones_2d/map_to_bev/pointpillar_scatter.py
Normal file
73
pcdet/models/backbones_2d/map_to_bev/pointpillar_scatter.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class PointPillarScatter(nn.Module):
|
||||
def __init__(self, model_cfg, grid_size, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.model_cfg = model_cfg
|
||||
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
|
||||
self.nx, self.ny, self.nz = grid_size
|
||||
assert self.nz == 1
|
||||
|
||||
def forward(self, batch_dict, **kwargs):
|
||||
pillar_features, coords = batch_dict['pillar_features'], batch_dict['voxel_coords']
|
||||
batch_spatial_features = []
|
||||
batch_size = coords[:, 0].max().int().item() + 1
|
||||
for batch_idx in range(batch_size):
|
||||
spatial_feature = torch.zeros(
|
||||
self.num_bev_features,
|
||||
self.nz * self.nx * self.ny,
|
||||
dtype=pillar_features.dtype,
|
||||
device=pillar_features.device)
|
||||
|
||||
batch_mask = coords[:, 0] == batch_idx
|
||||
this_coords = coords[batch_mask, :]
|
||||
indices = this_coords[:, 1] + this_coords[:, 2] * self.nx + this_coords[:, 3]
|
||||
indices = indices.type(torch.long)
|
||||
pillars = pillar_features[batch_mask, :]
|
||||
pillars = pillars.t()
|
||||
spatial_feature[:, indices] = pillars
|
||||
batch_spatial_features.append(spatial_feature)
|
||||
|
||||
batch_spatial_features = torch.stack(batch_spatial_features, 0)
|
||||
batch_spatial_features = batch_spatial_features.view(batch_size, self.num_bev_features * self.nz, self.ny, self.nx)
|
||||
batch_dict['spatial_features'] = batch_spatial_features
|
||||
return batch_dict
|
||||
|
||||
|
||||
class PointPillarScatter3d(nn.Module):
|
||||
def __init__(self, model_cfg, grid_size, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.model_cfg = model_cfg
|
||||
self.nx, self.ny, self.nz = self.model_cfg.INPUT_SHAPE
|
||||
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
|
||||
self.num_bev_features_before_compression = self.model_cfg.NUM_BEV_FEATURES // self.nz
|
||||
|
||||
def forward(self, batch_dict, **kwargs):
|
||||
pillar_features, coords = batch_dict['pillar_features'], batch_dict['voxel_coords']
|
||||
|
||||
batch_spatial_features = []
|
||||
batch_size = coords[:, 0].max().int().item() + 1
|
||||
for batch_idx in range(batch_size):
|
||||
spatial_feature = torch.zeros(
|
||||
self.num_bev_features_before_compression,
|
||||
self.nz * self.nx * self.ny,
|
||||
dtype=pillar_features.dtype,
|
||||
device=pillar_features.device)
|
||||
|
||||
batch_mask = coords[:, 0] == batch_idx
|
||||
this_coords = coords[batch_mask, :]
|
||||
indices = this_coords[:, 1] * self.ny * self.nx + this_coords[:, 2] * self.nx + this_coords[:, 3]
|
||||
indices = indices.type(torch.long)
|
||||
pillars = pillar_features[batch_mask, :]
|
||||
pillars = pillars.t()
|
||||
spatial_feature[:, indices] = pillars
|
||||
batch_spatial_features.append(spatial_feature)
|
||||
|
||||
batch_spatial_features = torch.stack(batch_spatial_features, 0)
|
||||
batch_spatial_features = batch_spatial_features.view(batch_size, self.num_bev_features_before_compression * self.nz, self.ny, self.nx)
|
||||
batch_dict['spatial_features'] = batch_spatial_features
|
||||
return batch_dict
|
||||
Reference in New Issue
Block a user