ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
FixedRadiusSearchOps.cu
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include "cloudViewer/core/CUDAUtils.h"
9 #include "core/Tensor.h"
10 #include "core/nns/FixedRadiusIndex.h"
11 #include "core/nns/FixedRadiusSearchImpl.cuh"
12 #include "core/nns/NeighborSearchAllocator.h"
13 #include "core/nns/NeighborSearchCommon.h"
14 
15 namespace cloudViewer {
16 namespace core {
17 namespace nns {
18 
19 template <class T>
20 void BuildSpatialHashTableCUDA(const Tensor& points,
21  double radius,
22  const Tensor& points_row_splits,
23  const Tensor& hash_table_splits,
24  Tensor& hash_table_index,
25  Tensor& hash_table_cell_splits) {
26  CUDAScopedDevice scoped_device(points.GetDevice());
27  const cudaStream_t stream = 0;
28  int texture_alignment = 512;
29 
30  void* temp_ptr = nullptr;
31  size_t temp_size = 0;
32 
33  impl::BuildSpatialHashTableCUDA(
34  stream, temp_ptr, temp_size, texture_alignment,
35  points.GetShape()[0], points.GetDataPtr<T>(), T(radius),
36  points_row_splits.GetShape()[0],
37  points_row_splits.GetDataPtr<int64_t>(),
38  hash_table_splits.GetDataPtr<uint32_t>(),
39  hash_table_cell_splits.GetShape()[0],
40  hash_table_cell_splits.GetDataPtr<uint32_t>(),
41  hash_table_index.GetDataPtr<uint32_t>());
42 
43  Device device = points.GetDevice();
44  Tensor temp_tensor =
45  Tensor::Empty({int64_t(temp_size)}, Dtype::UInt8, device);
46  temp_ptr = temp_tensor.GetDataPtr();
47 
48  impl::BuildSpatialHashTableCUDA(
49  stream, temp_ptr, temp_size, texture_alignment,
50  points.GetShape()[0], points.GetDataPtr<T>(), T(radius),
51  points_row_splits.GetShape()[0],
52  points_row_splits.GetDataPtr<int64_t>(),
53  hash_table_splits.GetDataPtr<uint32_t>(),
54  hash_table_cell_splits.GetShape()[0],
55  hash_table_cell_splits.GetDataPtr<uint32_t>(),
56  hash_table_index.GetDataPtr<uint32_t>());
57 }
58 
59 template <class T, class TIndex>
60 void FixedRadiusSearchCUDA(const Tensor& points,
61  const Tensor& queries,
62  double radius,
63  const Tensor& points_row_splits,
64  const Tensor& queries_row_splits,
65  const Tensor& hash_table_splits,
66  const Tensor& hash_table_index,
67  const Tensor& hash_table_cell_splits,
68  const Metric metric,
69  const bool ignore_query_point,
70  const bool return_distances,
71  const bool sort,
72  Tensor& neighbors_index,
73  Tensor& neighbors_row_splits,
74  Tensor& neighbors_distance) {
75  CUDAScopedDevice scoped_device(points.GetDevice());
76  const cudaStream_t stream = 0;
77  int texture_alignment = 512;
78 
79  Device device = points.GetDevice();
80  Dtype dtype = points.GetDtype();
81  Dtype index_dtype = Dtype::FromType<TIndex>();
82 
83  NeighborSearchAllocator<T, TIndex> output_allocator(device);
84  void* temp_ptr = nullptr;
85  size_t temp_size = 0;
86 
87  impl::FixedRadiusSearchCUDA<T, TIndex>(
88  stream, temp_ptr, temp_size, texture_alignment,
89  neighbors_row_splits.GetDataPtr<int64_t>(), points.GetShape()[0],
90  points.GetDataPtr<T>(), queries.GetShape()[0],
91  queries.GetDataPtr<T>(), T(radius), points_row_splits.GetShape()[0],
92  points_row_splits.GetDataPtr<int64_t>(),
93  queries_row_splits.GetShape()[0],
94  queries_row_splits.GetDataPtr<int64_t>(),
95  hash_table_splits.GetDataPtr<uint32_t>(),
96  hash_table_cell_splits.GetShape()[0],
97  hash_table_cell_splits.GetDataPtr<uint32_t>(),
98  hash_table_index.GetDataPtr<uint32_t>(), metric, ignore_query_point,
99  return_distances, output_allocator);
100 
101  Tensor temp_tensor =
102  Tensor::Empty({int64_t(temp_size)}, Dtype::UInt8, device);
103  temp_ptr = temp_tensor.GetDataPtr();
104 
105  impl::FixedRadiusSearchCUDA<T, TIndex>(
106  stream, temp_ptr, temp_size, texture_alignment,
107  neighbors_row_splits.GetDataPtr<int64_t>(), points.GetShape()[0],
108  points.GetDataPtr<T>(), queries.GetShape()[0],
109  queries.GetDataPtr<T>(), T(radius), points_row_splits.GetShape()[0],
110  points_row_splits.GetDataPtr<int64_t>(),
111  queries_row_splits.GetShape()[0],
112  queries_row_splits.GetDataPtr<int64_t>(),
113  hash_table_splits.GetDataPtr<uint32_t>(),
114  hash_table_cell_splits.GetShape()[0],
115  hash_table_cell_splits.GetDataPtr<uint32_t>(),
116  hash_table_index.GetDataPtr<uint32_t>(), metric, ignore_query_point,
117  return_distances, output_allocator);
118 
119  Tensor indices_unsorted = output_allocator.NeighborsIndex();
120  Tensor distances_unsorted = output_allocator.NeighborsDistance();
121 
122  if (!sort) {
123  neighbors_index = indices_unsorted;
124  neighbors_distance = distances_unsorted;
125  } else {
126  // Sort indices & distances.
127  temp_ptr = nullptr;
128  temp_size = 0;
129 
130  int64_t num_indices = indices_unsorted.GetShape()[0];
131  int64_t num_segments = neighbors_row_splits.GetShape()[0] - 1;
132  Tensor indices_sorted =
133  Tensor::Empty({num_indices}, index_dtype, device);
134  Tensor distances_sorted = Tensor::Empty({num_indices}, dtype, device);
135 
136  // Determine temp_size for sorting
137  impl::SortPairs(temp_ptr, temp_size, texture_alignment, num_indices,
138  num_segments,
139  neighbors_row_splits.GetDataPtr<int64_t>(),
140  indices_unsorted.GetDataPtr<TIndex>(),
141  distances_unsorted.GetDataPtr<T>(),
142  indices_sorted.GetDataPtr<TIndex>(),
143  distances_sorted.GetDataPtr<T>());
144 
145  temp_tensor = Tensor::Empty({int64_t(temp_size)}, Dtype::UInt8, device);
146  temp_ptr = temp_tensor.GetDataPtr();
147 
148  // Actually run the sorting.
149  impl::SortPairs(temp_ptr, temp_size, texture_alignment, num_indices,
150  num_segments,
151  neighbors_row_splits.GetDataPtr<int64_t>(),
152  indices_unsorted.GetDataPtr<TIndex>(),
153  distances_unsorted.GetDataPtr<T>(),
154  indices_sorted.GetDataPtr<TIndex>(),
155  distances_sorted.GetDataPtr<T>());
156  neighbors_index = indices_sorted;
157  neighbors_distance = distances_sorted;
158  }
159 }
160 
161 template <class T, class TIndex>
162 void HybridSearchCUDA(const Tensor& points,
163  const Tensor& queries,
164  double radius,
165  int max_knn,
166  const Tensor& points_row_splits,
167  const Tensor& queries_row_splits,
168  const Tensor& hash_table_splits,
169  const Tensor& hash_table_index,
170  const Tensor& hash_table_cell_splits,
171  const Metric metric,
172  Tensor& neighbors_index,
173  Tensor& neighbors_count,
174  Tensor& neighbors_distance) {
175  CUDAScopedDevice scoped_device(points.GetDevice());
176  const cudaStream_t stream = 0;
177 
178  Device device = points.GetDevice();
179 
180  NeighborSearchAllocator<T, TIndex> output_allocator(device);
181 
182  impl::HybridSearchCUDA<T, TIndex>(
183  stream, points.GetShape()[0], points.GetDataPtr<T>(),
184  queries.GetShape()[0], queries.GetDataPtr<T>(), T(radius), max_knn,
185  points_row_splits.GetShape()[0],
186  points_row_splits.GetDataPtr<int64_t>(),
187  queries_row_splits.GetShape()[0],
188  queries_row_splits.GetDataPtr<int64_t>(),
189  hash_table_splits.GetDataPtr<uint32_t>(),
190  hash_table_cell_splits.GetShape()[0],
191  hash_table_cell_splits.GetDataPtr<uint32_t>(),
192  hash_table_index.GetDataPtr<uint32_t>(), metric, output_allocator);
193 
194  neighbors_index = output_allocator.NeighborsIndex();
195  neighbors_distance = output_allocator.NeighborsDistance();
196  neighbors_count = output_allocator.NeighborsCount();
197 }
198 
199 #define INSTANTIATE_BUILD(T) \
200  template void BuildSpatialHashTableCUDA<T>( \
201  const Tensor& points, double radius, \
202  const Tensor& points_row_splits, const Tensor& hash_table_splits, \
203  Tensor& hash_table_index, Tensor& hash_table_cell_splits);
204 
205 #define INSTANTIATE_RADIUS(T, TIndex) \
206  template void FixedRadiusSearchCUDA<T, TIndex>( \
207  const Tensor& points, const Tensor& queries, double radius, \
208  const Tensor& points_row_splits, const Tensor& queries_row_splits, \
209  const Tensor& hash_table_splits, const Tensor& hash_table_index, \
210  const Tensor& hash_table_cell_splits, const Metric metric, \
211  const bool ignore_query_point, const bool return_distances, \
212  const bool sort, Tensor& neighbors_index, \
213  Tensor& neighbors_row_splits, Tensor& neighbors_distance);
214 
215 #define INSTANTIATE_HYBRID(T, TIndex) \
216  template void HybridSearchCUDA<T, TIndex>( \
217  const Tensor& points, const Tensor& queries, double radius, \
218  int max_knn, const Tensor& points_row_splits, \
219  const Tensor& queries_row_splits, const Tensor& hash_table_splits, \
220  const Tensor& hash_table_index, \
221  const Tensor& hash_table_cell_splits, const Metric metric, \
222  Tensor& neighbors_index, Tensor& neighbors_count, \
223  Tensor& neighbors_distance);
224 
225 INSTANTIATE_BUILD(float)
226 INSTANTIATE_BUILD(double)
227 
228 INSTANTIATE_RADIUS(float, int32_t)
229 INSTANTIATE_RADIUS(float, int64_t)
230 INSTANTIATE_RADIUS(double, int32_t)
231 INSTANTIATE_RADIUS(double, int64_t)
232 
233 INSTANTIATE_HYBRID(float, int32_t)
234 INSTANTIATE_HYBRID(float, int64_t)
235 INSTANTIATE_HYBRID(double, int32_t)
236 INSTANTIATE_HYBRID(double, int64_t)
237 
238 } // namespace nns
239 } // namespace core
240 } // namespace cloudViewer