ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
PointSampling.cuh
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 namespace cloudViewer {
11 namespace ml {
12 namespace contrib {
13 
14 static __device__ void __update(float *__restrict__ dists,
15  int *__restrict__ dists_i,
16  int idx1,
17  int idx2) {
18  const float v1 = dists[idx1], v2 = dists[idx2];
19  const int i1 = dists_i[idx1], i2 = dists_i[idx2];
20  dists[idx1] = max(v1, v2);
21  dists_i[idx1] = v2 > v1 ? i2 : i1;
22 }
23 
24 template <unsigned int block_size>
25 __global__ void furthest_point_sampling_kernel(
26  int b,
27  int n,
28  int m,
29  const float *__restrict__ dataset,
30  float *__restrict__ temp,
31  int *__restrict__ idxs) {
32  // dataset: (B, N, 3)
33  // tmp: (B, N)
34  // output:
35  // idx: (B, M)
36 
37  if (m <= 0) return;
38  __shared__ float dists[block_size];
39  __shared__ int dists_i[block_size];
40 
41  int batch_index = blockIdx.x;
42  dataset += batch_index * n * 3;
43  temp += batch_index * n;
44  idxs += batch_index * m;
45 
46  int tid = threadIdx.x;
47  const int stride = block_size;
48 
49  int old = 0;
50  if (threadIdx.x == 0) idxs[0] = old;
51 
52  __syncthreads();
53  for (int j = 1; j < m; j++) {
54  int besti = 0;
55  float best = -1;
56  float x1 = dataset[old * 3 + 0];
57  float y1 = dataset[old * 3 + 1];
58  float z1 = dataset[old * 3 + 2];
59  for (int k = tid; k < n; k += stride) {
60  float x2, y2, z2;
61  x2 = dataset[k * 3 + 0];
62  y2 = dataset[k * 3 + 1];
63  z2 = dataset[k * 3 + 2];
64  // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
65  // if (mag <= 1e-3)
66  // continue;
67 
68  float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
69  (z2 - z1) * (z2 - z1);
70  float d2 = min(d, temp[k]);
71  temp[k] = d2;
72  besti = d2 > best ? k : besti;
73  best = d2 > best ? d2 : best;
74  }
75  dists[tid] = best;
76  dists_i[tid] = besti;
77  __syncthreads();
78 
79  if (block_size >= 1024) {
80  if (tid < 512) {
81  __update(dists, dists_i, tid, tid + 512);
82  }
83  __syncthreads();
84  }
85 
86  if (block_size >= 512) {
87  if (tid < 256) {
88  __update(dists, dists_i, tid, tid + 256);
89  }
90  __syncthreads();
91  }
92  if (block_size >= 256) {
93  if (tid < 128) {
94  __update(dists, dists_i, tid, tid + 128);
95  }
96  __syncthreads();
97  }
98  if (block_size >= 128) {
99  if (tid < 64) {
100  __update(dists, dists_i, tid, tid + 64);
101  }
102  __syncthreads();
103  }
104  if (block_size >= 64) {
105  if (tid < 32) {
106  __update(dists, dists_i, tid, tid + 32);
107  }
108  __syncthreads();
109  }
110  if (block_size >= 32) {
111  if (tid < 16) {
112  __update(dists, dists_i, tid, tid + 16);
113  }
114  __syncthreads();
115  }
116  if (block_size >= 16) {
117  if (tid < 8) {
118  __update(dists, dists_i, tid, tid + 8);
119  }
120  __syncthreads();
121  }
122  if (block_size >= 8) {
123  if (tid < 4) {
124  __update(dists, dists_i, tid, tid + 4);
125  }
126  __syncthreads();
127  }
128  if (block_size >= 4) {
129  if (tid < 2) {
130  __update(dists, dists_i, tid, tid + 2);
131  }
132  __syncthreads();
133  }
134  if (block_size >= 2) {
135  if (tid < 1) {
136  __update(dists, dists_i, tid, tid + 1);
137  }
138  __syncthreads();
139  }
140 
141  old = dists_i[0];
142  if (tid == 0) idxs[j] = old;
143  }
144 }
145 
146 } // namespace contrib
147 } // namespace ml
148 } // namespace cloudViewer