Add File
This commit is contained in:
113
pcdet/ops/pointnet2/pointnet2_stack/src/voxel_query_gpu.cu
Normal file
113
pcdet/ops/pointnet2/pointnet2_stack/src/voxel_query_gpu.cu
Normal file
@@ -0,0 +1,113 @@
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include "voxel_query_gpu.h"
|
||||
#include "cuda_utils.h"
|
||||
|
||||
|
||||
__global__ void voxel_query_kernel_stack(int M, int R1, int R2, int R3, int nsample,
|
||||
float radius, int z_range, int y_range, int x_range, const float *new_xyz,
|
||||
const float *xyz, const int *new_coords, const int *point_indices, int *idx) {
|
||||
// :param new_coords: (M1 + M2 ..., 4) centers of the ball query
|
||||
// :param point_indices: (B, Z, Y, X)
|
||||
// output:
|
||||
// idx: (M1 + M2, nsample)
|
||||
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (pt_idx >= M) return;
|
||||
|
||||
new_xyz += pt_idx * 3;
|
||||
new_coords += pt_idx * 4;
|
||||
idx += pt_idx * nsample;
|
||||
|
||||
curandState state;
|
||||
curand_init(pt_idx, 0, 0, &state);
|
||||
|
||||
float radius2 = radius * radius;
|
||||
float new_x = new_xyz[0];
|
||||
float new_y = new_xyz[1];
|
||||
float new_z = new_xyz[2];
|
||||
|
||||
int batch_idx = new_coords[0];
|
||||
int new_coords_z = new_coords[1];
|
||||
int new_coords_y = new_coords[2];
|
||||
int new_coords_x = new_coords[3];
|
||||
|
||||
int cnt = 0;
|
||||
int cnt2 = 0;
|
||||
// for (int dz = -1*z_range; dz <= z_range; ++dz) {
|
||||
for (int dz = -1*z_range; dz <= z_range; ++dz) {
|
||||
int z_coord = new_coords_z + dz;
|
||||
if (z_coord < 0 || z_coord >= R1) continue;
|
||||
|
||||
for (int dy = -1*y_range; dy <= y_range; ++dy) {
|
||||
int y_coord = new_coords_y + dy;
|
||||
if (y_coord < 0 || y_coord >= R2) continue;
|
||||
|
||||
for (int dx = -1*x_range; dx <= x_range; ++dx) {
|
||||
int x_coord = new_coords_x + dx;
|
||||
if (x_coord < 0 || x_coord >= R3) continue;
|
||||
|
||||
int index = batch_idx * R1 * R2 * R3 + \
|
||||
z_coord * R2 * R3 + \
|
||||
y_coord * R3 + \
|
||||
x_coord;
|
||||
int neighbor_idx = point_indices[index];
|
||||
if (neighbor_idx < 0) continue;
|
||||
|
||||
float x_per = xyz[neighbor_idx*3 + 0];
|
||||
float y_per = xyz[neighbor_idx*3 + 1];
|
||||
float z_per = xyz[neighbor_idx*3 + 2];
|
||||
|
||||
float dist2 = (x_per - new_x) * (x_per - new_x) + (y_per - new_y) * (y_per - new_y) + (z_per - new_z) * (z_per - new_z);
|
||||
|
||||
if (dist2 > radius2) continue;
|
||||
|
||||
++cnt2;
|
||||
|
||||
if (cnt < nsample) {
|
||||
if (cnt == 0) {
|
||||
for (int l = 0; l < nsample; ++l) {
|
||||
idx[l] = neighbor_idx;
|
||||
}
|
||||
}
|
||||
idx[cnt] = neighbor_idx;
|
||||
++cnt;
|
||||
}
|
||||
// else {
|
||||
// float rnd = curand_uniform(&state);
|
||||
// if (rnd < (float(nsample) / cnt2)) {
|
||||
// int insertidx = ceilf(curand_uniform(&state) * nsample) - 1;
|
||||
// idx[insertidx] = neighbor_idx;
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cnt == 0) idx[0] = -1;
|
||||
}
|
||||
|
||||
|
||||
void voxel_query_kernel_launcher_stack(int M, int R1, int R2, int R3, int nsample,
|
||||
float radius, int z_range, int y_range, int x_range, const float *new_xyz,
|
||||
const float *xyz, const int *new_coords, const int *point_indices, int *idx) {
|
||||
// :param new_coords: (M1 + M2 ..., 4) centers of the voxel query
|
||||
// :param point_indices: (B, Z, Y, X)
|
||||
// output:
|
||||
// idx: (M1 + M2, nsample)
|
||||
|
||||
cudaError_t err;
|
||||
|
||||
dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
voxel_query_kernel_stack<<<blocks, threads>>>(M, R1, R2, R3, nsample, radius, z_range, y_range, x_range, new_xyz, xyz, new_coords, point_indices, idx);
|
||||
// cudaDeviceSynchronize(); // for using printf in kernel function
|
||||
|
||||
err = cudaGetLastError();
|
||||
if (cudaSuccess != err) {
|
||||
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user