diff --git a/pcdet/models/backbones_3d/vfe/mean_vfe.py b/pcdet/models/backbones_3d/vfe/mean_vfe.py new file mode 100644 index 0000000..42bd21f --- /dev/null +++ b/pcdet/models/backbones_3d/vfe/mean_vfe.py @@ -0,0 +1,31 @@ +import torch + +from .vfe_template import VFETemplate + + +class MeanVFE(VFETemplate): + def __init__(self, model_cfg, num_point_features, **kwargs): + super().__init__(model_cfg=model_cfg) + self.num_point_features = num_point_features + + def get_output_feature_dim(self): + return self.num_point_features + + def forward(self, batch_dict, **kwargs): + """ + Args: + batch_dict: + voxels: (num_voxels, max_points_per_voxel, C) + voxel_num_points: optional (num_voxels) + **kwargs: + + Returns: + vfe_features: (num_voxels, C) + """ + voxel_features, voxel_num_points = batch_dict['voxels'], batch_dict['voxel_num_points'] + points_mean = voxel_features[:, :, :].sum(dim=1, keepdim=False) + normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as(voxel_features) + points_mean = points_mean / normalizer + batch_dict['voxel_features'] = points_mean.contiguous() + + return batch_dict