Files
OpenPCDet/pcdet/ops/ingroup_inds/ingroup_inds_op.py
2025-09-21 20:19:25 +08:00

31 lines
632 B
Python

import torch
try:
    from . import ingroup_inds_cuda
except ImportError:
    # Keep the name bound when the compiled extension is unavailable.
    # Otherwise the alias below raises NameError and the whole module fails
    # to import, which defeats the purpose of catching ImportError here.
    ingroup_inds_cuda = None
    print('Can not import ingroup indices')

# Name under which the rest of this module calls into the extension.
# It is None when the extension could not be imported.
ingroup_indices = ingroup_inds_cuda
from torch.autograd import Function
class IngroupIndicesFunction(Function):
    """Autograd ``Function`` that forwards to ``ingroup_indices.forward``.

    The result is marked non-differentiable and ``backward`` supplies no
    gradient, so this op does not propagate gradients to its input.
    """

    @staticmethod
    def forward(ctx, group_inds):
        """Run the extension on ``group_inds`` and return the filled buffer.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            group_inds: Tensor handed unchanged to the extension.
                NOTE(review): presumably integer group ids on a CUDA
                device -- confirm against the kernel.

        Returns:
            A tensor created with ``torch.zeros_like(group_inds) - 1`` after
            it has been passed to ``ingroup_indices.forward``.
        """
        # Zeros shaped like the input, minus one (-1 for signed dtypes).
        result = torch.zeros_like(group_inds) - 1

        # The call's return value is not used; the extension is expected to
        # write its output into `result` in place.
        ingroup_indices.forward(group_inds, result)

        ctx.mark_non_differentiable(result)
        return result

    @staticmethod
    def backward(ctx, g):
        """Provide no gradient for ``group_inds``; ``g`` is ignored."""
        return None
# Functional entry point: calling `ingroup_inds(group_inds)` dispatches to
# IngroupIndicesFunction.forward through autograd's `Function.apply`.
ingroup_inds = IngroupIndicesFunction.apply