ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
KnnSearchImpl.cuh
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
#include <Helper.h>

#include <stdexcept>

#include "cloudViewer/core/CUDAUtils.h"
#include "core/nns/NeighborSearchCommon.h"
#include "cub/cub.cuh"
#include "utility/MiniVec.h"
16 
17 namespace cloudViewer {
18 namespace core {
19 namespace nns {
20 namespace impl {
21 
22 namespace {
23 
// Computes the distance between two NDIM-dimensional points under METRIC.
//
// - Linf: largest absolute per-component difference.
// - L1:   sum of absolute per-component differences.
// - any other value (L2 by default): the dot product of the difference
//   vector with itself, i.e. the SQUARED Euclidean distance (no square root
//   is taken).
//
// \param p1  First point.
// \param p2  Second point.
// \return    The distance as described above.
template <int METRIC = L2, class T, int NDIM>
inline __device__ T NeighborsDist(const utility::MiniVec<T, NDIM> &p1,
                                  const utility::MiniVec<T, NDIM> &p2) {
    // Start from zero so the L1 accumulation below never reads an
    // indeterminate value.
    T dist = T(0);
    if (METRIC == Linf) {
        utility::MiniVec<T, NDIM> d = (p1 - p2).abs();
        // Seed with component 0 and scan from 1 so that NDIM == 1 does not
        // touch d[1].
        dist = d[0];
        for (int i = 1; i < NDIM; ++i) {
            dist = dist > d[i] ? dist : d[i];
        }
    } else if (METRIC == L1) {
        utility::MiniVec<T, NDIM> d = (p1 - p2).abs();
        for (int i = 0; i < NDIM; ++i) {
            dist += d[i];
        }
    } else {
        utility::MiniVec<T, NDIM> d = p1 - p2;
        dist = d.dot(d);
    }
    return dist;
}
45 
// Exchanges the values stored at the two given addresses.
//
// \param x  Address of the first value.
// \param y  Address of the second value.
template <class T>
inline __device__ void Swap(T *x, T *y) {
    T held = *y;
    *y = *x;
    *x = held;
}
52 
// Sifts the entry at `root` down a binary max-heap stored in dist[0, k),
// applying every move to idx as well so both arrays stay paired.
//
// \param dist  Heap keys; the larger key is moved towards the root.
// \param idx   Payload array permuted in lock-step with dist.
// \param root  Position of the entry to sift down.
// \param k     Number of heap entries to consider.
template <class T, class TIndex>
inline __device__ void Heapify(T *dist, TIndex *idx, int root, int k) {
    for (int kid = 2 * root + 1; kid < k; kid = 2 * root + 1) {
        // Choose the larger of the two children, if a right child exists.
        const int sibling = kid + 1;
        if (sibling < k && dist[sibling] > dist[kid]) {
            kid = sibling;
        }
        // The parent already dominates its children: nothing left to do.
        if (dist[root] > dist[kid]) {
            break;
        }
        Swap<T>(&dist[root], &dist[kid]);
        Swap<TIndex>(&idx[root], &idx[kid]);
        root = kid;
    }
}
70 
// Sorts dist[0, k) in ascending order in place, permuting idx identically.
// The input is expected to already be a max-heap as maintained by Heapify.
//
// \param dist  Heap keys to sort.
// \param idx   Payload array reordered together with dist.
// \param k     Number of entries.
template <class T, class TIndex>
__device__ void HeapSort(T *dist, TIndex *idx, int k) {
    int last = k;
    while (--last > 0) {
        // Park the current maximum in the final unsorted slot ...
        Swap<T>(&dist[0], &dist[last]);
        Swap<TIndex>(&idx[0], &idx[last]);
        // ... then restore the heap over the shortened prefix.
        Heapify<T, TIndex>(dist, idx, 0, last);
    }
}
80 
// Brute-force k-nearest-neighbour search: each thread handles one query and
// scans every dataset point, keeping the knn closest candidates in a
// per-thread max-heap (best_dist[0] is the worst candidate kept so far).
//
// Launch layout: 1-D grid of 1-D blocks; the thread with global index q
// processes query q. Threads with q >= num_queries exit immediately.
//
// \param indices_ptr    Output, knn entries per query: indices of the
//                       selected points, ordered by ascending distance.
// \param distances_ptr  Output, knn entries per query: the matching
//                       distances as returned by NeighborsDist<METRIC>.
// \param num_points     Number of dataset points.
// \param points         Dataset coordinates, NDIM consecutive values per
//                       point.
// \param num_queries    Number of query points.
// \param queries        Query coordinates, NDIM consecutive values per query.
// \param knn            Neighbours requested per query. Values <= 0 or above
//                       the heap capacity (100) make the kernel return
//                       without writing any output.
//
// NOTE: heap slots are seeded with distance 1e10 and index 0. Points whose
// distance is >= 1e10 are therefore never selected, and slots that are never
// filled (e.g. num_points < knn) are reported as index 0 / distance 1e10.
template <class T, class TIndex, int METRIC = L2, int NDIM>
__global__ void KnnQueryKernel(TIndex *__restrict__ indices_ptr,
                               T *__restrict__ distances_ptr,
                               size_t num_points,
                               const T *__restrict__ points,
                               size_t num_queries,
                               const T *__restrict__ queries,
                               int knn) {
    // Capacity of the per-thread candidate heap.
    const int MAX_KNN = 100;

    // 64-bit index arithmetic: blockIdx.x * blockDim.x and knn * query_idx
    // can exceed the 32-bit range for large workloads.
    const size_t query_idx =
            static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (query_idx >= num_queries) return;

    // Refuse requests the fixed-size heap cannot hold instead of writing
    // past the end of the arrays; knn <= 0 has nothing to output.
    if (knn <= 0 || knn > MAX_KNN) return;

    typedef utility::MiniVec<T, NDIM> Vec_t;

    Vec_t query_pos(queries + NDIM * query_idx);

    T best_dist[MAX_KNN];
    TIndex best_idx[MAX_KNN];

    for (int i = 0; i < knn; i++) {
        best_dist[i] = static_cast<T>(1e10);
        best_idx[i] = 0;
    }

    for (size_t i = 0; i < num_points; i++) {
        Vec_t dataset_pos(points + NDIM * i);
        T dist = NeighborsDist<METRIC>(query_pos, dataset_pos);
        // Closer than the worst kept candidate: replace the heap root and
        // restore the heap order.
        if (dist < best_dist[0]) {
            best_dist[0] = dist;
            best_idx[0] = static_cast<TIndex>(i);
            Heapify(best_dist, best_idx, 0, knn);
        }
    }

    // Turn the max-heap into an ascending list before writing it out.
    HeapSort(best_dist, best_idx, knn);

    const size_t out_base = static_cast<size_t>(knn) * query_idx;
    for (int i = 0; i < knn; i++) {
        indices_ptr[out_base + i] = best_idx[i];
        distances_ptr[out_base + i] = best_dist[i];
    }
}
119 } // namespace
120 
// Launches KnnQueryKernel (L2 metric) on `stream`, one thread per query.
//
// \param stream         CUDA stream the kernel is enqueued on.
// \param indices_ptr    Output buffer, num_queries x knn entries.
// \param distances_ptr  Output buffer, num_queries x knn entries.
// \param num_points     Number of dataset points.
// \param points         Dataset coordinates, num_points x NDIM values.
// \param num_queries    Number of query points; 0 launches nothing.
// \param queries        Query coordinates, num_queries x NDIM values.
// \param knn            Neighbours per query. knn <= 0 is a no-op.
//
// \throws std::invalid_argument  if knn exceeds the kernel's fixed heap
//                                capacity (100).
// \throws std::runtime_error     if the kernel launch reports an error.
//
// The launch is asynchronous; this function does not synchronize the stream.
template <class T, class TIndex, int NDIM>
void KnnQuery(const cudaStream_t &stream,
              TIndex *indices_ptr,
              T *distances_ptr,
              size_t num_points,
              const T *const points,
              size_t num_queries,
              const T *const queries,
              int knn) {
    // Must match the per-thread heap capacity hard-coded in KnnQueryKernel.
    const int MAX_KNN = 100;

    // Nothing would be written for a non-positive knn, so skip the launch.
    if (knn <= 0) {
        return;
    }
    // A larger knn would overrun the kernel's fixed-size heap arrays.
    if (knn > MAX_KNN) {
        throw std::invalid_argument(
                "KnnQuery: knn exceeds the supported maximum of 100");
    }

    // input: queries: (m, NDIM), points: (n, NDIM), idx: (m, knn)
    const int BLOCKSIZE = 256;
    dim3 block(BLOCKSIZE, 1, 1);
    dim3 grid(0, 1, 1);
    grid.x = utility::DivUp(num_queries, block.x);

    if (grid.x) {
        KnnQueryKernel<T, TIndex, L2, NDIM><<<grid, block, 0, stream>>>(
                indices_ptr, distances_ptr, num_points, points, num_queries,
                queries, knn);

        // Kernel launches do not return a status; a bad launch
        // configuration is only visible through cudaGetLastError().
        const cudaError_t launch_status = cudaGetLastError();
        if (launch_status != cudaSuccess) {
            throw std::runtime_error(cudaGetErrorString(launch_status));
        }
    }
}
142 
143 } // namespace impl
144 } // namespace nns
145 } // namespace core
146 } // namespace cloudViewer