From 82563338daa3a633ac5c6d58b81841d790841723 Mon Sep 17 00:00:00 2001
From: inter
Date: Sun, 21 Sep 2025 20:18:42 +0800
Subject: [PATCH] Add File

---
 pcdet/models/view_transforms/depth_lss.py | 258 ++++++++++++++++++++++
 1 file changed, 258 insertions(+)
 create mode 100644 pcdet/models/view_transforms/depth_lss.py

diff --git a/pcdet/models/view_transforms/depth_lss.py b/pcdet/models/view_transforms/depth_lss.py
new file mode 100644
index 0000000..8fee406
--- /dev/null
+++ b/pcdet/models/view_transforms/depth_lss.py
@@ -0,0 +1,258 @@
+import torch
+from torch import nn
+from pcdet.ops.bev_pool import bev_pool
+
+
+def gen_dx_bx(xbound, ybound, zbound):
+    dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
+    bx = torch.Tensor([row[0] + row[2] / 2.0 for row in [xbound, ybound, zbound]])
+    nx = torch.LongTensor(
+        [(row[1] - row[0]) / row[2] for row in [xbound, ybound, zbound]]
+    )
+    return dx, bx, nx
+
+
+class DepthLSSTransform(nn.Module):
+    """
+    This module implements LSS, which lifts image features into 3D and then splats them onto BEV features.
+    This code is adapted from https://github.com/mit-han-lab/bevfusion/ with minimal modifications.
+    """
+    def __init__(self, model_cfg):
+        super().__init__()
+        self.model_cfg = model_cfg
+        in_channel = self.model_cfg.IN_CHANNEL
+        out_channel = self.model_cfg.OUT_CHANNEL
+        self.image_size = self.model_cfg.IMAGE_SIZE
+        self.feature_size = self.model_cfg.FEATURE_SIZE
+        xbound = self.model_cfg.XBOUND
+        ybound = self.model_cfg.YBOUND
+        zbound = self.model_cfg.ZBOUND
+        self.dbound = self.model_cfg.DBOUND
+        downsample = self.model_cfg.DOWNSAMPLE
+
+        dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
+        self.dx = nn.Parameter(dx, requires_grad=False)
+        self.bx = nn.Parameter(bx, requires_grad=False)
+        self.nx = nn.Parameter(nx, requires_grad=False)
+
+        self.C = out_channel
+        self.frustum = self.create_frustum()
+        self.D = self.frustum.shape[0]
+
+        self.dtransform = nn.Sequential(
+            nn.Conv2d(1, 8, 1),
+            nn.BatchNorm2d(8),
+            nn.ReLU(True),
+            nn.Conv2d(8, 32, 5, stride=4, padding=2),
+            nn.BatchNorm2d(32),
+            nn.ReLU(True),
+            nn.Conv2d(32, 64, 5, stride=2, padding=2),
+            nn.BatchNorm2d(64),
+            nn.ReLU(True),
+        )
+        self.depthnet = nn.Sequential(
+            nn.Conv2d(in_channel + 64, in_channel, 3, padding=1),
+            nn.BatchNorm2d(in_channel),
+            nn.ReLU(True),
+            nn.Conv2d(in_channel, in_channel, 3, padding=1),
+            nn.BatchNorm2d(in_channel),
+            nn.ReLU(True),
+            nn.Conv2d(in_channel, self.D + self.C, 1),
+        )
+        if downsample > 1:
+            assert downsample == 2, downsample
+            self.downsample = nn.Sequential(
+                nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=False),
+                nn.BatchNorm2d(out_channel),
+                nn.ReLU(True),
+                nn.Conv2d(out_channel, out_channel, 3, stride=downsample, padding=1, bias=False),
+                nn.BatchNorm2d(out_channel),
+                nn.ReLU(True),
+                nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=False),
+                nn.BatchNorm2d(out_channel),
+                nn.ReLU(True),
+            )
+        else:
+            self.downsample = nn.Identity()
+
+    def create_frustum(self):
+        iH, iW = self.image_size
+        fH, fW = self.feature_size
+
+        ds = torch.arange(*self.dbound, dtype=torch.float).view(-1, 1, 1).expand(-1, fH, fW)
+        D, _, _ = ds.shape
+        xs = torch.linspace(0, iW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW)
+        ys = torch.linspace(0, iH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW)
+        frustum = torch.stack((xs, ys, ds), -1)
+
+        return nn.Parameter(frustum, requires_grad=False)
+
+    def get_geometry(self, camera2lidar_rots, camera2lidar_trans, intrins, post_rots, post_trans, **kwargs):
+
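+        # The frustum stores (u, v, depth) points in augmented image space.
+        # Undo the image-space augmentation (post_rots / post_trans),
+        # back-project through the camera intrinsics, then apply the
+        # camera-to-lidar rotation/translation and any LiDAR-space
+        # augmentation passed via extra_rots / extra_trans, giving a
+        # B x N x D x H x W x 3 grid of LiDAR-frame coordinates.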
+        camera2lidar_rots = camera2lidar_rots.to(torch.float)
+        camera2lidar_trans = camera2lidar_trans.to(torch.float)
+        intrins = intrins.to(torch.float)
+        post_rots = post_rots.to(torch.float)
+        post_trans = post_trans.to(torch.float)
+
+        B, N, _ = camera2lidar_trans.shape
+
+        # undo post-transformation
+        # B x N x D x H x W x 3
+        points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
+        points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))
+
+        # cam_to_lidar
+        points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3]), 5)
+        combine = camera2lidar_rots.matmul(torch.inverse(intrins))
+        points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
+        points += camera2lidar_trans.view(B, N, 1, 1, 1, 3)
+
+        if "extra_rots" in kwargs:
+            extra_rots = kwargs["extra_rots"]
+            points = extra_rots.view(B, 1, 1, 1, 1, 3, 3).repeat(1, N, 1, 1, 1, 1, 1) \
+                .matmul(points.unsqueeze(-1)).squeeze(-1)
+
+        if "extra_trans" in kwargs:
+            extra_trans = kwargs["extra_trans"]
+            points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)
+
+        return points
+
+    def bev_pool(self, geom_feats, x):
+        geom_feats = geom_feats.to(torch.float)
+        x = x.to(torch.float)
+
+        B, N, D, H, W, C = x.shape
+        Nprime = B * N * D * H * W
+
+        # flatten x
+        x = x.reshape(Nprime, C)
+
+        # flatten indices
+        geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long()
+        geom_feats = geom_feats.view(Nprime, 3)
+        batch_ix = torch.cat([torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long) for ix in range(B)])
+        geom_feats = torch.cat((geom_feats, batch_ix), 1)
+
+        # filter out points that are outside box
+        kept = (
+            (geom_feats[:, 0] >= 0)
+            & (geom_feats[:, 0] < self.nx[0])
+            & (geom_feats[:, 1] >= 0)
+            & (geom_feats[:, 1] < self.nx[1])
+            & (geom_feats[:, 2] >= 0)
+            & (geom_feats[:, 2] < self.nx[2])
+        )
+        x = x[kept]
+        geom_feats = geom_feats[kept]
+        x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])
+
+        # collapse Z
+        final = torch.cat(x.unbind(dim=2), 1)
+
+        return final
+
+    def get_cam_feats(self, x, d):
+        B, N, C, fH, fW = x.shape
+
+        d = d.view(B * N, *d.shape[2:])
+        x = x.view(B * N, C, fH, fW)
+
+        d = self.dtransform(d)
+        x = torch.cat([d, x], dim=1)
+        x = self.depthnet(x)
+
+        depth = x[:, : self.D].softmax(dim=1)
+        x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)
+
+        x = x.view(B, N, self.C, self.D, fH, fW)
+        x = x.permute(0, 1, 3, 4, 5, 2)
+        return x
+
+    def forward(self, batch_dict):
+        """
+        Args:
+            batch_dict:
+                image_fpn (list[tensor]): image features after image neck
+
+        Returns:
+            batch_dict:
+                spatial_features_img (tensor): bev features from image modality
+        """
+        x = batch_dict['image_fpn']
+        x = x[0]
+        BN, C, H, W = x.size()
+        img = x.view(int(BN / 6), 6, C, H, W)
+
+        camera_intrinsics = batch_dict['camera_intrinsics']
+        camera2lidar = batch_dict['camera2lidar']
+        img_aug_matrix = batch_dict['img_aug_matrix']
+        lidar_aug_matrix = batch_dict['lidar_aug_matrix']
+        lidar2image = batch_dict['lidar2image']
+
+        intrins = camera_intrinsics[..., :3, :3]
+        post_rots = img_aug_matrix[..., :3, :3]
+        post_trans = img_aug_matrix[..., :3, 3]
+        camera2lidar_rots = camera2lidar[..., :3, :3]
+        camera2lidar_trans = camera2lidar[..., :3, 3]
+
+        points = batch_dict['points']
+
+        batch_size = BN // 6
+        depth = torch.zeros(batch_size, img.shape[1], 1, *self.image_size).to(points[0].device)
+
+        for b in range(batch_size):
+            batch_mask = points[:, 0] == b
+            cur_coords = points[batch_mask][:, 1:4]
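+            # Build a sparse depth map for sample b: project its LiDAR points into
+            # every camera (undo the LiDAR-space augmentation, apply lidar2image,
+            # then the image-space augmentation) and write the point distances into
+            # `depth`, which later conditions the per-pixel depth distribution
+            # predicted by depthnet.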
+            cur_img_aug_matrix = img_aug_matrix[b]
+            cur_lidar_aug_matrix = lidar_aug_matrix[b]
+            cur_lidar2image = lidar2image[b]
+
+            # inverse aug
+            cur_coords -= cur_lidar_aug_matrix[:3, 3]
+            cur_coords = torch.inverse(cur_lidar_aug_matrix[:3, :3]).matmul(
+                cur_coords.transpose(1, 0)
+            )
+            # lidar2image
+            cur_coords = cur_lidar2image[:, :3, :3].matmul(cur_coords)
+            cur_coords += cur_lidar2image[:, :3, 3].reshape(-1, 3, 1)
+            # get 2d coords
+            dist = cur_coords[:, 2, :]
+            cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
+            cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]
+
+            # do image aug
+            cur_coords = cur_img_aug_matrix[:, :3, :3].matmul(cur_coords)
+            cur_coords += cur_img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
+            cur_coords = cur_coords[:, :2, :].transpose(1, 2)
+
+            # swap (u, v) to (row, col) ordering for indexing the depth map
+            cur_coords = cur_coords[..., [1, 0]]
+
+            # filter points outside of images
+            on_img = (
+                (cur_coords[..., 0] < self.image_size[0])
+                & (cur_coords[..., 0] >= 0)
+                & (cur_coords[..., 1] < self.image_size[1])
+                & (cur_coords[..., 1] >= 0)
+            )
+            for c in range(on_img.shape[0]):
+                masked_coords = cur_coords[c, on_img[c]].long()
+                masked_dist = dist[c, on_img[c]]
+                depth[b, c, 0, masked_coords[:, 0], masked_coords[:, 1]] = masked_dist
+
+        extra_rots = lidar_aug_matrix[..., :3, :3]
+        extra_trans = lidar_aug_matrix[..., :3, 3]
+        geom = self.get_geometry(
+            camera2lidar_rots, camera2lidar_trans, intrins, post_rots,
+            post_trans, extra_rots=extra_rots, extra_trans=extra_trans,
+        )
+        # use the LiDAR point depths to assist depth prediction in the images
+        x = self.get_cam_feats(img, depth)
+        x = self.bev_pool(geom, x)
+        x = self.downsample(x)
+        # convert bev features from (b, c, x, y) to (b, c, y, x)
+        x = x.permute(0, 1, 3, 2)
+        batch_dict['spatial_features_img'] = x
+        return batch_dict
\ No newline at end of file
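
For reference, a minimal sketch of how the grid parameters produced by gen_dx_bx behave. The bound values below are illustrative nuScenes-style settings, not values taken from this patch; only the import path comes from the file added above:

    import torch
    from pcdet.models.view_transforms.depth_lss import gen_dx_bx

    # Each bound is [min, max, step] in metres; z is typically a single slab.
    xbound = [-54.0, 54.0, 0.3]
    ybound = [-54.0, 54.0, 0.3]
    zbound = [-10.0, 10.0, 20.0]

    dx, bx, nx = gen_dx_bx(xbound, ybound, zbound)
    # dx: cell size per axis        -> tensor([ 0.3,   0.3,  20.0])
    # bx: centre of the first cell  -> tensor([-53.85, -53.85, 0.0])
    # nx: number of cells per axis  -> tensor([360, 360, 1])

With settings like these, DepthLSSTransform.bev_pool buckets the lifted camera features into a 360 x 360 x 1 grid (note that it passes nx in z, x, y order to the bev_pool op), and the single z slab is then folded into the channel dimension by the "collapse Z" step.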