1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
8 #include "cloudViewer/core/CUDAUtils.h"
9 #include "core/Tensor.h"
10 #include "core/nns/FixedRadiusIndex.h"
11 #include "core/nns/FixedRadiusSearchImpl.cuh"
12 #include "core/nns/NeighborSearchAllocator.h"
13 #include "core/nns/NeighborSearchCommon.h"
15 namespace cloudViewer {
20 void BuildSpatialHashTableCUDA(const Tensor& points,
22 const Tensor& points_row_splits,
23 const Tensor& hash_table_splits,
24 Tensor& hash_table_index,
25 Tensor& hash_table_cell_splits) {
26 CUDAScopedDevice scoped_device(points.GetDevice());
27 const cudaStream_t stream = 0;
28 int texture_alignment = 512;
30 void* temp_ptr = nullptr;
33 impl::BuildSpatialHashTableCUDA(
34 stream, temp_ptr, temp_size, texture_alignment,
35 points.GetShape()[0], points.GetDataPtr<T>(), T(radius),
36 points_row_splits.GetShape()[0],
37 points_row_splits.GetDataPtr<int64_t>(),
38 hash_table_splits.GetDataPtr<uint32_t>(),
39 hash_table_cell_splits.GetShape()[0],
40 hash_table_cell_splits.GetDataPtr<uint32_t>(),
41 hash_table_index.GetDataPtr<uint32_t>());
43 Device device = points.GetDevice();
45 Tensor::Empty({int64_t(temp_size)}, Dtype::UInt8, device);
46 temp_ptr = temp_tensor.GetDataPtr();
48 impl::BuildSpatialHashTableCUDA(
49 stream, temp_ptr, temp_size, texture_alignment,
50 points.GetShape()[0], points.GetDataPtr<T>(), T(radius),
51 points_row_splits.GetShape()[0],
52 points_row_splits.GetDataPtr<int64_t>(),
53 hash_table_splits.GetDataPtr<uint32_t>(),
54 hash_table_cell_splits.GetShape()[0],
55 hash_table_cell_splits.GetDataPtr<uint32_t>(),
56 hash_table_index.GetDataPtr<uint32_t>());
59 template <class T, class TIndex>
60 void FixedRadiusSearchCUDA(const Tensor& points,
61 const Tensor& queries,
63 const Tensor& points_row_splits,
64 const Tensor& queries_row_splits,
65 const Tensor& hash_table_splits,
66 const Tensor& hash_table_index,
67 const Tensor& hash_table_cell_splits,
69 const bool ignore_query_point,
70 const bool return_distances,
72 Tensor& neighbors_index,
73 Tensor& neighbors_row_splits,
74 Tensor& neighbors_distance) {
75 CUDAScopedDevice scoped_device(points.GetDevice());
76 const cudaStream_t stream = 0;
77 int texture_alignment = 512;
79 Device device = points.GetDevice();
80 Dtype dtype = points.GetDtype();
81 Dtype index_dtype = Dtype::FromType<TIndex>();
83 NeighborSearchAllocator<T, TIndex> output_allocator(device);
84 void* temp_ptr = nullptr;
87 impl::FixedRadiusSearchCUDA<T, TIndex>(
88 stream, temp_ptr, temp_size, texture_alignment,
89 neighbors_row_splits.GetDataPtr<int64_t>(), points.GetShape()[0],
90 points.GetDataPtr<T>(), queries.GetShape()[0],
91 queries.GetDataPtr<T>(), T(radius), points_row_splits.GetShape()[0],
92 points_row_splits.GetDataPtr<int64_t>(),
93 queries_row_splits.GetShape()[0],
94 queries_row_splits.GetDataPtr<int64_t>(),
95 hash_table_splits.GetDataPtr<uint32_t>(),
96 hash_table_cell_splits.GetShape()[0],
97 hash_table_cell_splits.GetDataPtr<uint32_t>(),
98 hash_table_index.GetDataPtr<uint32_t>(), metric, ignore_query_point,
99 return_distances, output_allocator);
102 Tensor::Empty({int64_t(temp_size)}, Dtype::UInt8, device);
103 temp_ptr = temp_tensor.GetDataPtr();
105 impl::FixedRadiusSearchCUDA<T, TIndex>(
106 stream, temp_ptr, temp_size, texture_alignment,
107 neighbors_row_splits.GetDataPtr<int64_t>(), points.GetShape()[0],
108 points.GetDataPtr<T>(), queries.GetShape()[0],
109 queries.GetDataPtr<T>(), T(radius), points_row_splits.GetShape()[0],
110 points_row_splits.GetDataPtr<int64_t>(),
111 queries_row_splits.GetShape()[0],
112 queries_row_splits.GetDataPtr<int64_t>(),
113 hash_table_splits.GetDataPtr<uint32_t>(),
114 hash_table_cell_splits.GetShape()[0],
115 hash_table_cell_splits.GetDataPtr<uint32_t>(),
116 hash_table_index.GetDataPtr<uint32_t>(), metric, ignore_query_point,
117 return_distances, output_allocator);
119 Tensor indices_unsorted = output_allocator.NeighborsIndex();
120 Tensor distances_unsorted = output_allocator.NeighborsDistance();
123 neighbors_index = indices_unsorted;
124 neighbors_distance = distances_unsorted;
126 // Sort indices & distances.
130 int64_t num_indices = indices_unsorted.GetShape()[0];
131 int64_t num_segments = neighbors_row_splits.GetShape()[0] - 1;
132 Tensor indices_sorted =
133 Tensor::Empty({num_indices}, index_dtype, device);
134 Tensor distances_sorted = Tensor::Empty({num_indices}, dtype, device);
136 // Determine temp_size for sorting
137 impl::SortPairs(temp_ptr, temp_size, texture_alignment, num_indices,
139 neighbors_row_splits.GetDataPtr<int64_t>(),
140 indices_unsorted.GetDataPtr<TIndex>(),
141 distances_unsorted.GetDataPtr<T>(),
142 indices_sorted.GetDataPtr<TIndex>(),
143 distances_sorted.GetDataPtr<T>());
145 temp_tensor = Tensor::Empty({int64_t(temp_size)}, Dtype::UInt8, device);
146 temp_ptr = temp_tensor.GetDataPtr();
148 // Actually run the sorting.
149 impl::SortPairs(temp_ptr, temp_size, texture_alignment, num_indices,
151 neighbors_row_splits.GetDataPtr<int64_t>(),
152 indices_unsorted.GetDataPtr<TIndex>(),
153 distances_unsorted.GetDataPtr<T>(),
154 indices_sorted.GetDataPtr<TIndex>(),
155 distances_sorted.GetDataPtr<T>());
156 neighbors_index = indices_sorted;
157 neighbors_distance = distances_sorted;
161 template <class T, class TIndex>
162 void HybridSearchCUDA(const Tensor& points,
163 const Tensor& queries,
166 const Tensor& points_row_splits,
167 const Tensor& queries_row_splits,
168 const Tensor& hash_table_splits,
169 const Tensor& hash_table_index,
170 const Tensor& hash_table_cell_splits,
172 Tensor& neighbors_index,
173 Tensor& neighbors_count,
174 Tensor& neighbors_distance) {
175 CUDAScopedDevice scoped_device(points.GetDevice());
176 const cudaStream_t stream = 0;
178 Device device = points.GetDevice();
180 NeighborSearchAllocator<T, TIndex> output_allocator(device);
182 impl::HybridSearchCUDA<T, TIndex>(
183 stream, points.GetShape()[0], points.GetDataPtr<T>(),
184 queries.GetShape()[0], queries.GetDataPtr<T>(), T(radius), max_knn,
185 points_row_splits.GetShape()[0],
186 points_row_splits.GetDataPtr<int64_t>(),
187 queries_row_splits.GetShape()[0],
188 queries_row_splits.GetDataPtr<int64_t>(),
189 hash_table_splits.GetDataPtr<uint32_t>(),
190 hash_table_cell_splits.GetShape()[0],
191 hash_table_cell_splits.GetDataPtr<uint32_t>(),
192 hash_table_index.GetDataPtr<uint32_t>(), metric, output_allocator);
194 neighbors_index = output_allocator.NeighborsIndex();
195 neighbors_distance = output_allocator.NeighborsDistance();
196 neighbors_count = output_allocator.NeighborsCount();
199 #define INSTANTIATE_BUILD(T) \
200 template void BuildSpatialHashTableCUDA<T>( \
201 const Tensor& points, double radius, \
202 const Tensor& points_row_splits, const Tensor& hash_table_splits, \
203 Tensor& hash_table_index, Tensor& hash_table_cell_splits);
205 #define INSTANTIATE_RADIUS(T, TIndex) \
206 template void FixedRadiusSearchCUDA<T, TIndex>( \
207 const Tensor& points, const Tensor& queries, double radius, \
208 const Tensor& points_row_splits, const Tensor& queries_row_splits, \
209 const Tensor& hash_table_splits, const Tensor& hash_table_index, \
210 const Tensor& hash_table_cell_splits, const Metric metric, \
211 const bool ignore_query_point, const bool return_distances, \
212 const bool sort, Tensor& neighbors_index, \
213 Tensor& neighbors_row_splits, Tensor& neighbors_distance);
215 #define INSTANTIATE_HYBRID(T, TIndex) \
216 template void HybridSearchCUDA<T, TIndex>( \
217 const Tensor& points, const Tensor& queries, double radius, \
218 int max_knn, const Tensor& points_row_splits, \
219 const Tensor& queries_row_splits, const Tensor& hash_table_splits, \
220 const Tensor& hash_table_index, \
221 const Tensor& hash_table_cell_splits, const Metric metric, \
222 Tensor& neighbors_index, Tensor& neighbors_count, \
223 Tensor& neighbors_distance);
225 INSTANTIATE_BUILD(float)
226 INSTANTIATE_BUILD(double)
228 INSTANTIATE_RADIUS(float, int32_t)
229 INSTANTIATE_RADIUS(float, int64_t)
230 INSTANTIATE_RADIUS(double, int32_t)
231 INSTANTIATE_RADIUS(double, int64_t)
233 INSTANTIATE_HYBRID(float, int32_t)
234 INSTANTIATE_HYBRID(float, int64_t)
235 INSTANTIATE_HYBRID(double, int32_t)
236 INSTANTIATE_HYBRID(double, int64_t)
240 } // namespace cloudViewer