This commit is contained in:
2025-09-21 20:19:24 +08:00
parent 8834983715
commit 7c0bc04646

View File

@@ -0,0 +1,54 @@
#include <assert.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
#include <vector>
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
void ingroup_inds_launcher(
const long *group_inds_data,
long *out_inds_data,
int N,
int max_group_id
);
void ingroup_inds_gpu(
at::Tensor group_inds,
at::Tensor out_inds
);
void ingroup_inds_gpu(
at::Tensor group_inds,
at::Tensor out_inds
) {
CHECK_INPUT(group_inds);
CHECK_INPUT(out_inds);
int N = group_inds.size(0);
int max_group_id = group_inds.max().item().toLong();
long *group_inds_data = group_inds.data_ptr<long>();
long *out_inds_data = out_inds.data_ptr<long>();
ingroup_inds_launcher(
group_inds_data,
out_inds_data,
N,
max_group_id
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &ingroup_inds_gpu, "cuda version of get_inner_win_inds of SST");
}