Add File
This commit is contained in:
26
pcdet/models/backbones_2d/map_to_bev/height_compression.py
Normal file
26
pcdet/models/backbones_2d/map_to_bev/height_compression.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class HeightCompression(nn.Module):
|
||||||
|
def __init__(self, model_cfg, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model_cfg = model_cfg
|
||||||
|
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
|
||||||
|
|
||||||
|
def forward(self, batch_dict):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
batch_dict:
|
||||||
|
encoded_spconv_tensor: sparse tensor
|
||||||
|
Returns:
|
||||||
|
batch_dict:
|
||||||
|
spatial_features:
|
||||||
|
|
||||||
|
"""
|
||||||
|
encoded_spconv_tensor = batch_dict['encoded_spconv_tensor']
|
||||||
|
spatial_features = encoded_spconv_tensor.dense()
|
||||||
|
N, C, D, H, W = spatial_features.shape
|
||||||
|
spatial_features = spatial_features.view(N, C * D, H, W)
|
||||||
|
batch_dict['spatial_features'] = spatial_features
|
||||||
|
batch_dict['spatial_features_stride'] = batch_dict['encoded_spconv_tensor_stride']
|
||||||
|
return batch_dict
|
||||||
Reference in New Issue
Block a user