1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
8 #include "ml/contrib/BallQuery.cuh"
10 namespace cloudViewer {
// Ball-query kernel. Each thread handles one query point (pt_idx) of one batch
// element (bs_idx): it walks the n candidate points of that batch element and
// computes the squared Euclidean distance from the query point to each one.
//
// Launch layout read by the code: blockIdx.y selects the batch element;
// blockIdx.x * blockDim.x + threadIdx.x selects the query point. Threads whose
// indices fall outside b / m return immediately, so the grid may overshoot.
//
// Memory layout implied by the index arithmetic below (contiguous, row-major):
//   new_xyz: (b, m, 3) floats, read only
//   xyz:     (b, n, 3) floats, read only
//   idx:     (b, m, nsample) ints
//
// NOTE(review): the declarations of n, m, radius, nsample and cnt, the radius
// test, and the body of the inner `l` loop are not visible in this excerpt.
// Presumably cnt counts accepted neighbours and idx receives their indices --
// confirm against the full file before relying on that.
14 __global__ void ball_query_kernel(int b,
19 const float *__restrict__ new_xyz,
20 const float *__restrict__ xyz,
21 int *__restrict__ idx) {
25 // idx: (B, M, nsample)
// One thread per (batch element, query point) pair.
26 int bs_idx = blockIdx.y;
27 int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
// Bounds guard: drop threads beyond the batch count or the query-point count.
28 if (bs_idx >= b || pt_idx >= m) return;
// Rebase the pointers so the rest of the kernel can use local offsets:
// new_xyz -> this query point's 3 coordinates,
// xyz     -> first candidate point of this batch element,
// idx     -> this query point's nsample-wide slot.
30 new_xyz += bs_idx * m * 3 + pt_idx * 3;
31 xyz += bs_idx * n * 3;
32 idx += bs_idx * m * nsample + pt_idx * nsample;
// Square the radius once; distances below are kept squared as well (d2).
34 float radius2 = radius * radius;
// Load the query point's coordinates once, ahead of the candidate loop.
35 float new_x = new_xyz[0];
36 float new_y = new_xyz[1];
37 float new_z = new_xyz[2];
// Linear scan over all n candidate points of this batch element.
40 for (int k = 0; k < n; ++k) {
41 float x = xyz[k * 3 + 0];
42 float y = xyz[k * 3 + 1];
43 float z = xyz[k * 3 + 2];
// Squared Euclidean distance between the query point and candidate k.
44 float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
45 (new_z - z) * (new_z - z);
// Iterates over all nsample slots; the loop body is not visible here.
// NOTE(review): presumably it pre-fills idx with k -- confirm.
48 for (int l = 0; l < nsample; ++l) {
// Stop scanning candidates once cnt has reached nsample.
54 if (cnt >= nsample) break;
59 } // namespace contrib
61 } // namespace cloudViewer