1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
10 namespace cloudViewer {
// Pairwise step of the block-wide max/argmax reduction.
//
// Reads the (distance, index) pairs stored at positions idx1 and idx2, then
// writes into position idx1 the larger of the two distances together with
// the index that belongs to it. On a tie (v2 == v1) the index already at idx1
// is kept, because the index selection uses a strict `v2 > v1`. Position idx2
// is only read, never written.
//
// Parameters:
//   dists      - distance values; dists[idx1] is overwritten.
//   dists_i    - point indices parallel to `dists`; dists_i[idx1] is
//                overwritten.
//   idx1, idx2 - positions to compare. NOTE(review): their declarations (and
//                this function's closing brace) are not visible in this
//                excerpt; presumably plain ints — confirm.
//
// In this file it is only called from furthest_point_sampling_kernel, with
// that kernel's __shared__ arrays and idx2 = idx1 + s for s = 512, 256, ..., 1.
14 static __device__ void __update(float *__restrict__ dists,
15 int *__restrict__ dists_i,
18 const float v1 = dists[idx1], v2 = dists[idx2];
19 const int i1 = dists_i[idx1], i2 = dists_i[idx2];
// Keep the larger distance in slot idx1 ...
20 dists[idx1] = max(v1, v2);
// ... and the point index that goes with it; ties keep i1.
21 dists_i[idx1] = v2 > v1 ? i2 : i1;
// Furthest point sampling, one thread block per batch element (blockIdx.x).
//
// Template parameter:
//   block_size - the thread count the kernel is written for. It sizes the
//                __shared__ arrays `dists`/`dists_i` and is the stride of each
//                thread's point loop. NOTE(review): correctness presumably
//                requires launching with blockDim.x == block_size, so that every
//                shared slot is written before the reduction reads it. The
//                halving ladder (512 down to 1) also appears to assume a
//                power-of-two block_size no larger than 1024 — confirm at the
//                launch site.
//
// Parameters (pointers are advanced to this block's slice before use):
//   dataset - per batch element, n points stored as consecutive (x, y, z)
//             floats; point k is at dataset[k*3 + 0..2]. The block's slice
//             starts at batch_index * n * 3.
//   temp    - per batch element, n floats of scratch (slice starts at
//             batch_index * n). For point k it is combined as
//             min(d, temp[k]), i.e. the running minimum squared distance to
//             the points selected so far. NOTE(review): the write-back of that
//             minimum and the required initial contents (presumably a large
//             value) are not visible in this excerpt — confirm with the caller.
//   idxs    - output, m ints per batch element (slice starts at
//             batch_index * m). idxs[0] and then idxs[j] receive `old`, the
//             index of the point selected for that slot.
//   n, m, old, best, besti, x2/y2/z2 are used below but declared on lines not
//   visible in this excerpt. Presumably n = points per batch element,
//   m = samples to draw, and old = index of the last selected point — confirm.
//
// NOTE(review): batch_index is int. If n is also int, batch_index * n * 3 is
// evaluated in 32-bit arithmetic and can overflow for very large inputs;
// consider widening to int64_t.
//
// NOTE(review): no __syncthreads() appears anywhere in this excerpt. Each
// reduction step reads shared slots written by other threads in the previous
// step, so a barrier is required between steps (and after the per-thread
// store into dists/dists_i). Presumably these barriers are on the elided
// lines — verify.
24 template <unsigned int block_size>
25 __global__ void furthest_point_sampling_kernel(
29 const float *__restrict__ dataset,
30 float *__restrict__ temp,
31 int *__restrict__ idxs) {
// Shared scratch for the block-wide max/argmax reduction: one slot per thread.
38 __shared__ float dists[block_size];
39 __shared__ int dists_i[block_size];
// One block per batch element: move the pointers to this block's slice.
41 int batch_index = blockIdx.x;
42 dataset += batch_index * n * 3;
43 temp += batch_index * n;
44 idxs += batch_index * m;
// Each thread visits points tid, tid + block_size, tid + 2 * block_size, ...
46 int tid = threadIdx.x;
47 const int stride = block_size;
// The first sample is the point `old` (its initialization is not visible here).
50 if (threadIdx.x == 0) idxs[0] = old;
// Select the remaining samples, one per iteration, for j = 1 .. m - 1.
53 for (int j = 1; j < m; j++) {
// Coordinates of the most recently selected point.
56 float x1 = dataset[old * 3 + 0];
57 float y1 = dataset[old * 3 + 1];
58 float z1 = dataset[old * 3 + 2];
59 for (int k = tid; k < n; k += stride) {
61 x2 = dataset[k * 3 + 0];
62 y2 = dataset[k * 3 + 1];
63 z2 = dataset[k * 3 + 2];
// Commented-out code kept from the original; it is not executed.
64 // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// Squared Euclidean distance from point k to point `old`.
68 float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
69 (z2 - z1) * (z2 - z1);
// Distance from point k to its nearest already-selected point.
70 float d2 = min(d, temp[k]);
// Track this thread's farthest candidate. The strict `>` keeps the earlier k
// on ties.
72 besti = d2 > best ? k : besti;
73 best = d2 > best ? d2 : best;
// Tree reduction over the shared arrays: each step calls __update(tid, tid + s),
// halving s from 512 down to 1. Steps larger than the block are skipped by the
// block_size guards. NOTE(review): the per-step `tid < s` guards and the store
// of best/besti into dists[tid]/dists_i[tid] are not visible in this excerpt.
79 if (block_size >= 1024) {
81 __update(dists, dists_i, tid, tid + 512);
86 if (block_size >= 512) {
88 __update(dists, dists_i, tid, tid + 256);
92 if (block_size >= 256) {
94 __update(dists, dists_i, tid, tid + 128);
98 if (block_size >= 128) {
100 __update(dists, dists_i, tid, tid + 64);
104 if (block_size >= 64) {
106 __update(dists, dists_i, tid, tid + 32);
110 if (block_size >= 32) {
112 __update(dists, dists_i, tid, tid + 16);
116 if (block_size >= 16) {
118 __update(dists, dists_i, tid, tid + 8);
122 if (block_size >= 8) {
124 __update(dists, dists_i, tid, tid + 4);
128 if (block_size >= 4) {
130 __update(dists, dists_i, tid, tid + 2);
134 if (block_size >= 2) {
136 __update(dists, dists_i, tid, tid + 1);
// Thread 0 records `old` as the j-th sample. Presumably `old` has just been
// set from the reduction result (dists_i[0]) on an elided line — confirm.
142 if (tid == 0) idxs[j] = old;
146 } // namespace contrib
148 } // namespace cloudViewer