1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
12 #include "cloudViewer/core/CUDAUtils.h"
13 #include "core/nns/NeighborSearchCommon.h"
14 #include "cub/cub.cuh"
15 #include "utility/MiniVec.h"
17 namespace cloudViewer {
// NeighborsDist: distance between points p1 and p2 (each an NDIM-component
// utility::MiniVec<T, NDIM>) under the compile-time METRIC selector, which
// defaults to L2. This excerpt shows only part of each branch:
//   - first branch: maximum absolute per-component difference (presumably the
//     Linf case; its `if` condition is not visible here),
//   - `METRIC == L1` branch: iterates over the absolute per-component
//     differences (the accumulation line is not visible here),
//   - fallback branch: starts from the signed difference p1 - p2.
// NOTE(review): whether the fallback (L2) branch returns the squared or the
// true Euclidean distance is not visible here. Confirm before comparing
// results against radii.
24 template <int METRIC = L2, class T, int NDIM>
25 inline __device__ T NeighborsDist(const utility::MiniVec<T, NDIM> &p1,
26 const utility::MiniVec<T, NDIM> &p2) {
29 utility::MiniVec<T, NDIM> d = (p1 - p2).abs();
30 dist = d[0] > d[1] ? d[0] : d[1];  // reads d[1] unconditionally: this branch needs NDIM >= 2
31 for (int i = 2; i < NDIM; ++i) {
32 dist = dist > d[i] ? dist : d[i];  // running max over the remaining components
34 } else if (METRIC == L1) {
35 utility::MiniVec<T, NDIM> d = (p1 - p2).abs();
36 for (int i = 0; i < NDIM; ++i) {
40 utility::MiniVec<T, NDIM> d = p1 - p2;
// Swap: presumably exchanges *x and *y. Its template header and body are not
// visible in this excerpt. Heapify calls it on dist and idx in lockstep so
// each distance stays paired with its point index.
47 inline __device__ void Swap(T *x, T *y) {
// Heapify: sift-down step for a heap stored in the parallel arrays
// dist[0..k) and idx[0..k), starting at position `root`. The children of
// node r are 2r+1 and 2r+2. The larger child is selected
// (dist[child + 1] > dist[child]), so the ordering maintained is a max-heap
// on dist: the root holds the largest distance. Every swap on dist is
// mirrored on idx so indices stay paired with their distances.
// NOTE(review): the loop header and early-exit lines are not visible in this
// excerpt. Presumably sifting stops once dist[root] > dist[child]; confirm.
53 template <class T, class TIndex>
54 inline __device__ void Heapify(T *dist, TIndex *idx, int root, int k) {
55 int child = root * 2 + 1;  // left child of root (right child is child + 1)
58 if (child + 1 < k && dist[child + 1] > dist[child]) {  // pick the larger in-range child
61 if (dist[root] > dist[child]) {  // root already dominates its larger child
64 Swap<T>(&dist[root], &dist[child]);
65 Swap<TIndex>(&idx[root], &idx[child]);
// HeapSort: in-place sort of the parallel arrays dist[0..k) / idx[0..k).
// It repeatedly swaps the heap root into the last slot of the shrinking
// range [0, i) and then sifts down the remaining i elements. Because Heapify
// maintains a max-heap, the result is ascending order of dist (nearest
// first), with idx permuted identically.
// NOTE(review): no heap-construction pass is visible here, so the input is
// presumably expected to already be a valid max-heap (as KnnQueryKernel's
// candidate buffer is kept). Confirm before calling on arbitrary data.
71 template <class T, class TIndex>
72 __device__ void HeapSort(T *dist, TIndex *idx, int k) {
74 for (i = k - 1; i > 0; i--) {
75 Swap<T>(&dist[0], &dist[i]);  // move current maximum to its final slot
76 Swap<TIndex>(&idx[0], &idx[i]);
77 Heapify<T, TIndex>(dist, idx, 0, i);  // restore heap on the unsorted prefix [0, i)
// KnnQueryKernel: brute-force k-nearest-neighbour search, one thread per query.
// Launch layout: 1-D. The thread index blockIdx.x * blockDim.x + threadIdx.x
// selects the query, and threads with query_idx >= num_queries return at once.
// Input layout: points and queries are read as contiguous NDIM-component rows
// (row i starts at base + NDIM * i).
// Output layout: row-major, i.e. indices_ptr[query_idx * knn + i] and
// distances_ptr[query_idx * knn + i] for i in [0, knn).
// Algorithm: the thread keeps its knn best candidates in a max-heap
// (best_dist / best_idx) whose root best_dist[0] is the worst distance kept so
// far. Every point is compared against that root. A closer point presumably
// replaces the root (those lines are not visible here), and Heapify then
// restores the heap. After the scan, HeapSort orders the candidates ascending
// by distance before they are written out.
// NOTE(review): the best_dist / best_idx declarations and their initialisation
// are not visible here. If they are fixed-size per-thread arrays, knn must not
// exceed their capacity.
// NOTE(review): query_idx is int. The output offset `i + knn * query_idx` may
// overflow for very large num_queries * knn; confirm the parameter types.
81 template <class T, class TIndex, int METRIC = L2, int NDIM>
82 __global__ void KnnQueryKernel(TIndex *__restrict__ indices_ptr,
83 T *__restrict__ distances_ptr,
85 const T *__restrict__ points,
87 const T *__restrict__ queries,
89 int query_idx = blockIdx.x * blockDim.x + threadIdx.x;
90 if (query_idx >= num_queries) return;  // tail threads of the last block
92 typedef utility::MiniVec<T, NDIM> Vec_t;
94 Vec_t query_pos(queries + NDIM * query_idx);
99 for (int i = 0; i < knn; i++) {
104 for (int i = 0; i < num_points; i++) {
105 Vec_t dataset_pos(points + NDIM * i);
106 T dist = NeighborsDist<METRIC>(query_pos, dataset_pos);
107 if (dist < best_dist[0]) {  // best_dist[0] is the heap root: worst distance currently kept
110 Heapify(best_dist, best_idx, 0, knn);
113 HeapSort(best_dist, best_idx, knn);  // ascending order, nearest neighbour first
114 for (int i = 0; i < knn; i++) {
115 indices_ptr[i + knn * query_idx] = best_idx[i];
116 distances_ptr[i + knn * query_idx] = best_dist[i];
// KnnQuery: host-side launcher for KnnQueryKernel using the L2 metric on the
// given CUDA stream. It launches 256 threads per block and
// DivUp(num_queries, 256) blocks, giving one thread per query, and requests
// no dynamic shared memory.
// NOTE(review): launch-error checking (e.g. cudaGetLastError after the
// launch) is not visible in this excerpt. Confirm it happens.
121 template <class T, class TIndex, int NDIM>
122 void KnnQuery(const cudaStream_t &stream,
126 const T *const points,
128 const T *const queries,
130 // input: queries: (m, NDIM), points: (n, NDIM); output: indices and distances: (m, knn), row-major
131 const int BLOCKSIZE = 256;
132 dim3 block(BLOCKSIZE, 1, 1);
134 grid.x = utility::DivUp(num_queries, block.x);  // enough blocks to cover every query
137 KnnQueryKernel<T, TIndex, L2, NDIM><<<grid, block, 0, stream>>>(
138 indices_ptr, distances_ptr, num_points, points, num_queries,
146 } // namespace cloudViewer