ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
BallQuery.cu
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include "ml/contrib/BallQuery.cuh"
9 
namespace cloudViewer {
namespace ml {
namespace contrib {

// Collects up to `nsample` indices of points in `xyz` whose squared distance
// to `query` is strictly less than `radius2`.
//
// Parameters:
//   n       - number of points in `xyz`.
//   radius2 - squared search radius; a point exactly on the sphere is NOT
//             selected (strict `<`).
//   nsample - capacity of `out`; nothing is written when nsample <= 0.
//   query   - pointer to one (x, y, z) triple.
//   xyz     - pointer to n contiguous (x, y, z) triples.
//   out     - pointer to nsample ints.
//
// Behaviour:
//   * Neighbour indices are written in ascending order of k.
//   * On the first hit, every slot of `out` is set to that index, so unused
//     slots end up "padded" with the first neighbour found.
//   * The scan stops as soon as nsample neighbours have been recorded.
//   * If no point is within the radius, `out` is left untouched. Callers
//     must pre-initialise it.
static __host__ __device__ inline void ball_query_point(
        int n,
        float radius2,
        int nsample,
        const float *__restrict__ query,
        const float *__restrict__ xyz,
        int *__restrict__ out) {
    const float qx = query[0];
    const float qy = query[1];
    const float qz = query[2];

    int cnt = 0;
    const float *p = xyz;
    // The `cnt < nsample` condition also makes nsample <= 0 a no-op. Without
    // it, the first hit would write out[0] past a zero-length buffer.
    for (int k = 0; k < n && cnt < nsample; ++k, p += 3) {
        const float dx = qx - p[0];
        const float dy = qy - p[1];
        const float dz = qz - p[2];
        const float d2 = dx * dx + dy * dy + dz * dz;
        if (d2 < radius2) {
            if (cnt == 0) {
                // Pad all slots with the first neighbour so that rows with
                // fewer than nsample hits still contain valid indices.
                for (int l = 0; l < nsample; ++l) {
                    out[l] = k;
                }
            }
            out[cnt] = k;
            ++cnt;
        }
    }
}

// Batched ball query: for each query point, records up to `nsample` indices
// of points within `radius`.
//
// Tensor layouts (contiguous, row-major):
//   new_xyz: (B, M, 3)        query points
//   xyz:     (B, N, 3)        candidate points
//   idx:     (B, M, nsample)  output neighbour indices into the N axis
//
// Launch layout: blockIdx.y selects the batch element. The flat x index
// (blockIdx.x * blockDim.x + threadIdx.x) selects the query point. Threads
// outside [0, b) x [0, m) exit immediately, so the grid may over-cover both
// axes.
//
// Output contract: see ball_query_point. In particular, rows whose query
// point has no neighbour within `radius` are NOT written. `idx` must
// therefore be pre-initialised by the caller.
__global__ void ball_query_kernel(int b,
                                  int n,
                                  int m,
                                  float radius,
                                  int nsample,
                                  const float *__restrict__ new_xyz,
                                  const float *__restrict__ xyz,
                                  int *__restrict__ idx) {
    const int bs_idx = blockIdx.y;
    // Widen before multiplying so the product cannot wrap in 32 bits.
    const long long pt_idx =
            static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m || nsample <= 0) return;

    // 64-bit offsets: B*M*nsample or B*N*3 may exceed INT_MAX for large
    // batches.
    const long long query_row = static_cast<long long>(bs_idx) * m + pt_idx;
    const float *query = new_xyz + query_row * 3;
    const float *cloud = xyz + static_cast<long long>(bs_idx) * n * 3;
    int *out = idx + query_row * nsample;

    ball_query_point(n, radius * radius, nsample, query, cloud, out);
}

}  // namespace contrib
}  // namespace ml
}  // namespace cloudViewer