Add File
This commit is contained in:
54
pcdet/ops/ingroup_inds/src/ingroup_inds.cpp
Normal file
54
pcdet/ops/ingroup_inds/src/ingroup_inds.cpp
Normal file
@@ -0,0 +1,54 @@
|
||||
#include <assert.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/serialize/tensor.h>
|
||||
#include <vector>
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
|
||||
void ingroup_inds_launcher(
|
||||
const long *group_inds_data,
|
||||
long *out_inds_data,
|
||||
int N,
|
||||
int max_group_id
|
||||
);
|
||||
|
||||
|
||||
void ingroup_inds_gpu(
|
||||
at::Tensor group_inds,
|
||||
at::Tensor out_inds
|
||||
);
|
||||
|
||||
void ingroup_inds_gpu(
|
||||
at::Tensor group_inds,
|
||||
at::Tensor out_inds
|
||||
) {
|
||||
|
||||
CHECK_INPUT(group_inds);
|
||||
CHECK_INPUT(out_inds);
|
||||
int N = group_inds.size(0);
|
||||
int max_group_id = group_inds.max().item().toLong();
|
||||
|
||||
|
||||
long *group_inds_data = group_inds.data_ptr<long>();
|
||||
long *out_inds_data = out_inds.data_ptr<long>();
|
||||
|
||||
ingroup_inds_launcher(
|
||||
group_inds_data,
|
||||
out_inds_data,
|
||||
N,
|
||||
max_group_id
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &ingroup_inds_gpu, "cuda version of get_inner_win_inds of SST");
|
||||
}
|
||||
Reference in New Issue
Block a user