Add File
This commit is contained in:
54
pcdet/ops/ingroup_inds/src/ingroup_inds.cpp
Normal file
54
pcdet/ops/ingroup_inds/src/ingroup_inds.cpp
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#include <assert.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <torch/serialize/tensor.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define CHECK_CUDA(x) \
|
||||||
|
TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
|
||||||
|
#define CHECK_CONTIGUOUS(x) \
|
||||||
|
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
|
||||||
|
#define CHECK_INPUT(x) \
|
||||||
|
CHECK_CUDA(x); \
|
||||||
|
CHECK_CONTIGUOUS(x)
|
||||||
|
|
||||||
|
|
||||||
|
void ingroup_inds_launcher(
|
||||||
|
const long *group_inds_data,
|
||||||
|
long *out_inds_data,
|
||||||
|
int N,
|
||||||
|
int max_group_id
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
void ingroup_inds_gpu(
|
||||||
|
at::Tensor group_inds,
|
||||||
|
at::Tensor out_inds
|
||||||
|
);
|
||||||
|
|
||||||
|
void ingroup_inds_gpu(
|
||||||
|
at::Tensor group_inds,
|
||||||
|
at::Tensor out_inds
|
||||||
|
) {
|
||||||
|
|
||||||
|
CHECK_INPUT(group_inds);
|
||||||
|
CHECK_INPUT(out_inds);
|
||||||
|
int N = group_inds.size(0);
|
||||||
|
int max_group_id = group_inds.max().item().toLong();
|
||||||
|
|
||||||
|
|
||||||
|
long *group_inds_data = group_inds.data_ptr<long>();
|
||||||
|
long *out_inds_data = out_inds.data_ptr<long>();
|
||||||
|
|
||||||
|
ingroup_inds_launcher(
|
||||||
|
group_inds_data,
|
||||||
|
out_inds_data,
|
||||||
|
N,
|
||||||
|
max_group_id
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("forward", &ingroup_inds_gpu, "cuda version of get_inner_win_inds of SST");
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user