1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
10 #include "cloudViewer/core/CUDAUtils.h"
11 #include "core/MemoryManager.h"
12 #include "core/Tensor.h"
13 #include "core/linalg/AddMM.h"
14 #include "core/nns/KnnIndex.h"
15 #include "core/nns/KnnSearchImpl.cuh"
16 #include "core/nns/NeighborSearchAllocator.h"
17 #include "core/nns/kernel/BlockSelect.cuh"
18 #include "core/nns/kernel/DistancesUtils.cuh"
19 #include "core/nns/kernel/L2Select.cuh"
21 namespace cloudViewer {
// Expands to the brute-force KNN kernel invocation specialized for the
// compile-time point dimension NDIM (lets the kernel fully unroll its
// per-dimension loops). Relies on `stream`, `indices_ptr`, `distances_ptr`,
// `points`, `queries`, `knn`, `T` and `TIndex` being in scope at the
// expansion site. No comments are placed inside the macro body because the
// line-continuation backslashes would be broken by `//`.
25 #define CALL_KNN_BRUTE_FORCE(NDIM) \
26 impl::KnnQuery<T, TIndex, NDIM>( \
27 stream, indices_ptr, distances_ptr, points.GetShape(0), \
28 points.GetDataPtr<T>(), queries.GetShape(0), \
29 queries.GetDataPtr<T>(), knn);
// Brute-force KNN search on CUDA for low-dimensional point sets.
// For each query point, finds its `knn` nearest neighbors among `points`
// by direct pairwise-distance evaluation; the kernel is dispatched on the
// point dimension (1..8) via CALL_KNN_BRUTE_FORCE. Neighbor indices and
// distances are allocated and written through `output_allocator`;
// `query_neighbors_row_splits` is filled with uniform row splits of `knn`
// neighbors per query.
// NOTE(review): several statements are not visible in this view of the file
// (e.g. the `knn` parameter declaration, the `indices_ptr`/`distances_ptr`
// declarations, the early `return`, the switch `case` labels and closing
// braces) — the added comments describe only what is visible.
31 template <class T, class TIndex, class OUTPUT_ALLOCATOR>
32 void KnnSearchCUDABruteForce(const Tensor& points,
33 const Tensor& queries,
35 OUTPUT_ALLOCATOR& output_allocator,
36 Tensor& query_neighbors_row_splits) {
// Pin all subsequent CUDA work to the device owning `points`.
37 CUDAScopedDevice scoped_device(points.GetDevice());
38 const cudaStream_t stream = cuda::GetStream();
39 int num_points = points.GetShape(0);
40 int num_queries = queries.GetShape(0);
// Early-out on empty input: zero row splits, zero-sized outputs.
42 // Return if input points are empty.
43 if (num_points == 0 || num_queries == 0) {
44 query_neighbors_row_splits.Fill(0);
48 output_allocator.AllocIndices(&indices_ptr, 0);
49 output_allocator.AllocDistances(&distances_ptr, 0);
// Clamp so we never request more neighbors than points exist.
52 knn = num_points > knn ? knn : num_points;
54 // Allocate output tensors.
// Uniform `knn` neighbors per query -> row splits form an arithmetic series.
// NOTE(review): Arange(0, num_queries * knn, knn) produces num_queries
// entries, while row-splits conventionally need num_queries + 1 — the final
// split is presumably written by the caller (see the `j <= num_queries_i`
// loop in KnnSearchCUDA); confirm.
55 query_neighbors_row_splits.AsRvalue() =
56 Tensor::Arange(0, num_queries * knn, knn);
59 const size_t num_indices = knn * num_queries;
60 output_allocator.AllocIndices(&indices_ptr, num_indices);
61 output_allocator.AllocDistances(&distances_ptr, num_indices);
63 // Call kernel function.
// Dispatch on the point dimension; each arm instantiates the kernel for a
// fixed NDIM (dimensions 1..8 are supported, anything else errors below).
64 switch (points.GetShape(1)) {
66 CALL_KNN_BRUTE_FORCE(1);
69 CALL_KNN_BRUTE_FORCE(2);
72 CALL_KNN_BRUTE_FORCE(3);
75 CALL_KNN_BRUTE_FORCE(4);
78 CALL_KNN_BRUTE_FORCE(5);
81 CALL_KNN_BRUTE_FORCE(6);
84 CALL_KNN_BRUTE_FORCE(7);
87 CALL_KNN_BRUTE_FORCE(8);
91 "KnnSearchCUDABruteForce only support data with dimension "
// Optimized (tiled, GEMM-based) KNN search on CUDA.
// Uses the decomposition |q - p|^2 = |q|^2 - 2 q.p + |p|^2: the cross term
// is computed per tile with AddMM (alpha = -2, beta = 0), and a fused
// L2 top-k selection (runL2SelectMin, which also receives |p|^2) extracts
// the k smallest distances per tile; |q|^2 is added to the selected rows
// afterwards. When all points fit in one column tile, selection writes
// straight into the output tensors; otherwise per-tile partial top-k results
// are staged in `buf_distances`/`buf_indices` and merged at the end.
// NOTE(review): this view of the function is elided — the `knn` parameter,
// the output pointer declarations, the `tile_rows`/`tile_cols` declarations,
// the `buf_indices` declaration, `case`/brace lines, and the name of the
// final merge kernel (whose argument list is visible at the bottom) are not
// shown.
97 template <class T, class TIndex, class OUTPUT_ALLOCATOR>
98 void KnnSearchCUDAOptimized(const Tensor& points,
99 const Tensor& queries,
101 OUTPUT_ALLOCATOR& output_allocator,
102 Tensor& query_neighbors_row_splits) {
// Pin all subsequent CUDA work to the device owning `points`.
103 CUDAScopedDevice scoped_device(points.GetDevice());
104 int num_points = points.GetShape(0);
105 int num_queries = queries.GetShape(0);
106 int dim = points.GetShape(1);
107 Device device = points.GetDevice();
108 Dtype dtype = points.GetDtype();
109 Dtype index_dtype = Dtype::FromType<TIndex>();
// Early-out on empty input: zero row splits, zero-sized outputs.
111 // Return if input points are empty.
112 if (num_points == 0 || num_queries == 0) {
113 query_neighbors_row_splits.Fill(0);
117 output_allocator.AllocIndices(&indices_ptr, 0);
118 output_allocator.AllocDistances(&distances_ptr, 0);
// Clamp so we never request more neighbors than points exist.
121 knn = num_points > knn ? knn : num_points;
123 // Allocate output tensors.
// Uniform `knn` neighbors per query -> row splits form an arithmetic series.
124 query_neighbors_row_splits.AsRvalue() =
125 Tensor::Arange(0, num_queries * knn, knn);
128 const size_t num_indices = knn * num_queries;
130 output_allocator.AllocIndices(&indices_ptr, num_indices);
131 output_allocator.AllocDistances(&distances_ptr, num_indices);
// Squared L2 norms of every point and query (row-wise sum of squares).
133 // Calculate norms, |p|^2, |q|^2.
134 Tensor point_norms = points.Mul(points).Sum({1});
135 Tensor query_norms = queries.Mul(queries).Sum({1});
// Pick tile sizes from problem size, dimension and scalar width so the
// temporaries below fit in memory.
137 // Divide queries and points into chunks (rows and cols).
140 chooseTileSize(num_queries, num_points, dim, sizeof(T), tile_rows,
// Number of column (point) tiles; each contributes `knn` staged candidates.
142 int num_cols = utility::DivUp(num_points, tile_cols);
// Scratch: one tile of pairwise distances, plus per-tile top-k staging
// buffers of width num_cols * knn (candidates from every column tile).
144 // Allocate temporary memory space.
145 Tensor temp_distances =
146 Tensor::Empty({tile_rows, tile_cols}, dtype, device);
147 Tensor buf_distances =
148 Tensor::Empty({tile_rows, num_cols * knn}, dtype, device);
150 Tensor::Empty({tile_rows, num_cols * knn}, index_dtype, device);
// Outer loop: one row tile (chunk of queries) at a time.
153 for (int i = 0; i < num_queries; i += tile_rows) {
154 int num_queries_i = std::min(tile_rows, num_queries - i);
155 Tensor queries_i = queries.Slice(0, i, i + num_queries_i);
156 Tensor query_norms_i = query_norms.Slice(0, i, i + num_queries_i);
157 Tensor buf_distances_row_view =
158 buf_distances.Slice(0, 0, num_queries_i);
159 Tensor buf_indices_row_view = buf_indices.Slice(0, 0, num_queries_i);
// Run this row tile on its own stream for overlap; `cur_stream` is the
// stream installed by the scoped guard.
161 CUDAScopedStream scoped_stream(CUDAScopedStream::CreateNewStream);
162 cudaStream_t cur_stream = cuda::GetStream();
// Inner loop: one column tile (chunk of points) at a time.
163 for (int j = 0; j < num_points; j += tile_cols) {
164 int num_points_j = std::min(tile_cols, num_points - j);
165 int col_j = j / tile_cols;
166 Tensor points_j = points.Slice(0, j, j + num_points_j);
167 Tensor point_norms_j =
168 point_norms.Slice(0, j, j + num_points_j);
169 Tensor temp_distances_view =
170 temp_distances.Slice(0, 0, num_queries_i)
171 .Slice(1, 0, num_points_j);
// Staging slots for this column tile's top-k candidates.
172 Tensor buf_distances_col_view = buf_distances_row_view.Slice(
173 1, knn * col_j, (col_j + 1) * knn);
174 Tensor buf_indices_col_view = buf_indices_row_view.Slice(
175 1, knn * col_j, (col_j + 1) * knn);
// temp = -2 * queries_i . points_j^T  (the cross term of |q - p|^2).
178 AddMM(queries_i, points_j.T(), temp_distances_view, -2.0, 0.0);
179 // Topk selection & Add |p|^2, |q|^2 with fused kernel
// Single column tile: select directly into the final output views.
180 if (tile_cols == num_points) {
181 Tensor out_indices_view =
182 output_allocator.NeighborsIndex_()
183 .View({num_queries, knn})
184 .Slice(0, i, i + num_queries_i);
185 Tensor out_distances_view =
186 output_allocator.NeighborsDistance_()
187 .View({num_queries, knn})
188 .Slice(0, i, i + num_queries_i);
189 runL2SelectMin<T, TIndex>(cur_stream, temp_distances_view,
190 point_norms_j, out_distances_view,
191 out_indices_view, knn, num_cols,
// Complete the distance: add |q|^2 (broadcast down each output row).
193 out_distances_view.Add_(
194 query_norms_i.View({num_queries_i, 1}));
// Multiple column tiles: select into this tile's staging slots.
196 runL2SelectMin<T, TIndex>(
197 cur_stream, temp_distances_view, point_norms_j,
198 buf_distances_col_view, buf_indices_col_view, knn,
199 num_cols, tile_cols);
200 buf_distances_col_view.Add_(
201 query_norms_i.View({num_queries_i, 1}));
204 // Write results to output tensor.
// Merge path: rebase the per-tile local indices (runIncrementIndex —
// presumably adds each tile's start offset; verify) and reduce the
// num_cols * knn staged candidates down to the final knn per query.
// NOTE(review): the callee of the merge call below is on an elided line.
205 if (tile_cols != num_points) {
206 runIncrementIndex<TIndex>(cur_stream, buf_indices_row_view, knn,
209 cur_stream, buf_distances_row_view.GetDataPtr<T>(),
210 buf_indices_row_view.GetDataPtr<TIndex>(),
211 distances_ptr + knn * i, indices_ptr + knn * i, false,
212 knn, buf_distances_row_view.GetShape(1),
213 buf_distances_row_view.GetShape(0));
// Batched KNN search entry point on CUDA.
// `points_row_splits` / `queries_row_splits` partition the flat `points` /
// `queries` tensors into batches; each batch is searched independently
// (brute force for small dimension and knn, tiled/optimized otherwise) and
// the per-batch results are concatenated into `neighbors_index`,
// `neighbors_row_splits` and `neighbors_distance`.
// NOTE(review): this view is elided — the `int knn` parameter line, the
// brute-force/optimized `if`/`else` lines, loop closing braces, and the
// `.View(...)` arguments of the single-batch path are not shown.
219 template <class T, class TIndex>
220 void KnnSearchCUDA(const Tensor& points,
221 const Tensor& points_row_splits,
222 const Tensor& queries,
223 const Tensor& queries_row_splits,
225 Tensor& neighbors_index,
226 Tensor& neighbors_row_splits,
227 Tensor& neighbors_distance) {
228 CUDAScopedDevice scoped_device(points.GetDevice());
229 int num_points = points.GetShape(0);
230 int num_queries = queries.GetShape(0);
231 Device device = points.GetDevice();
// Heuristic: brute force only pays off for low dimension and small k.
// NOTE(review): the brute-force kernel dispatch has arms for dimensions
// 1..8, but `< 8` excludes dimension 8 — confirm this is intentional.
232 bool brute_force = points.GetShape(1) < 8 && knn <= 32;
// row_splits has batch_size + 1 entries.
234 const int batch_size = points_row_splits.GetShape(0) - 1;
// One result allocator per batch; merged below.
236 std::vector<NeighborSearchAllocator<T, TIndex>> batch_output_allocators(
237 batch_size, NeighborSearchAllocator<T, TIndex>(device));
239 int64_t last_neighbors_count = 0;
240 for (int i = 0; i < batch_size; ++i) {
// Per-batch views of the flat point/query tensors.
241 const Tensor points_i =
242 points.Slice(0, points_row_splits[i].Item<int64_t>(),
243 points_row_splits[i + 1].Item<int64_t>());
244 const Tensor queries_i =
245 queries.Slice(0, queries_row_splits[i].Item<int64_t>(),
246 queries_row_splits[i + 1].Item<int64_t>());
247 int64_t num_queries_i = queries_i.GetShape(0);
248 Tensor neighbors_row_splits_i = neighbors_row_splits.Slice(
249 0, queries_row_splits[i].Item<int64_t>(),
250 queries_row_splits[i + 1].Item<int64_t>());
// NOTE(review): this raw pointer is dereferenced on the host below, which
// assumes `neighbors_row_splits` lives in host-accessible memory — verify.
251 int64_t* neighbors_row_splits_i_ptr =
252 neighbors_row_splits_i.GetDataPtr<int64_t>();
255 KnnSearchCUDABruteForce<T, TIndex>(points_i, queries_i, knn,
256 batch_output_allocators[i],
257 neighbors_row_splits_i);
259 KnnSearchCUDAOptimized<T, TIndex>(points_i, queries_i, knn,
260 batch_output_allocators[i],
261 neighbors_row_splits_i);
// Shift this batch's row splits by the neighbors accumulated so far. The
// `<=` bound also writes one element past the slice — presumably the next
// batch's first split (overwritten by its own search) or the final split
// of the whole tensor; verify.
265 for (int j = 0; j <= num_queries_i; ++j) {
266 neighbors_row_splits_i_ptr[j] += last_neighbors_count;
269 last_neighbors_count = neighbors_row_splits_i_ptr[num_queries_i];
// Single batch: expose the allocator's tensors directly, no copy.
272 if (batch_size == 1) {
273 neighbors_index = batch_output_allocators[0].NeighborsIndex().View(
276 batch_output_allocators[0].NeighborsDistance().View(
// Multi-batch: concatenate all per-batch results into one allocation.
282 NeighborSearchAllocator<T, TIndex> output_allocator(device);
283 int64_t neighbors_size = 0;
284 for (const auto& a : batch_output_allocators) {
285 neighbors_size += a.NeighborsIndex().GetShape(0);
287 TIndex* neighbors_index_ptr;
288 T* neighbors_distance_ptr;
289 output_allocator.AllocIndices(&neighbors_index_ptr, neighbors_size);
290 output_allocator.AllocDistances(&neighbors_distance_ptr, neighbors_size);
292 last_neighbors_count = 0;
293 for (int i = 0; i < batch_size; ++i) {
294 auto& a = batch_output_allocators[i];
// Batch-local indices must be rebased by the batch's first point index.
295 int64_t offset = points_row_splits[i].Item<int64_t>();
296 int64_t num_neighbors_i = a.NeighborsIndex().GetShape(0);
297 if (num_neighbors_i) {
// NOTE(review): the offset-adjusted tensor below appears unused — the
// Memcpy reads a.IndicesPtr(), i.e. the UN-offset indices. Unless
// IndicesPtr() aliases the Add() result (Add() normally returns a new
// tensor), the point offset is never applied to the copied indices.
// Verify against NeighborSearchAllocator; looks like a bug.
298 Tensor NeighborIndexAccumulated = a.NeighborsIndex().Add(offset);
299 MemoryManager::Memcpy(neighbors_index_ptr + last_neighbors_count,
300 device, a.IndicesPtr(), device,
301 sizeof(TIndex) * num_neighbors_i);
302 MemoryManager::Memcpy(neighbors_distance_ptr + last_neighbors_count,
303 device, a.DistancesPtr(), device,
304 sizeof(T) * num_neighbors_i);
305 last_neighbors_count += num_neighbors_i;
308 neighbors_index = output_allocator.NeighborsIndex();
309 neighbors_distance = output_allocator.NeighborsDistance();
// Explicit-instantiation helper: emits the KnnSearchCUDA template definition
// for one (scalar, index) type pair so it is compiled into this translation
// unit and linkable from elsewhere. No comments inside the macro body — the
// line-continuation backslashes would be broken by `//`.
312 #define INSTANTIATE(T, TIndex) \
313 template void KnnSearchCUDA<T, TIndex>( \
314 const Tensor& points, const Tensor& points_row_splits, \
315 const Tensor& queries, const Tensor& queries_row_splits, int knn, \
316 Tensor& neighbors_index, Tensor& neighbors_row_splits, \
317 Tensor& neighbors_distance);
// All supported scalar/index combinations: float/double points with
// int32/int64 neighbor indices.
319 INSTANTIATE(float, int32_t)
320 INSTANTIATE(float, int64_t)
321 INSTANTIATE(double, int32_t)
322 INSTANTIATE(double, int64_t)
326 } // namespace cloudViewer