1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
13 #include <cub/cub.cuh>
15 #include "core/nns/MemoryAllocation.h"
16 #include "core/nns/NeighborSearchCommon.h"
17 #include "utility/MiniVec.h"
19 namespace cloudViewer {
27 using Vec3 = utility::MiniVec<T, 3>;
29 /// Computes the distance of two points and tests if the distance is below a
32 /// \tparam METRIC The distance metric. One of L1, L2, Linf.
33 /// \tparam T Floating point type for the distances.
35 /// \param p1 A 3D point
36 /// \param p2 Another 3D point
37 /// \param dist Output parameter for the distance.
38 /// \param threshold The scalar threshold.
40 /// \return Returns true if the distance is <= threshold.
42 template <int METRIC = L2, class T>
43 inline __device__ bool NeighborTest(const Vec3<T>& p1,
49 Vec3<T> d = (p1 - p2).abs();
50 *dist = d[0] > d[1] ? d[0] : d[1];
51 *dist = *dist > d[2] ? *dist : d[2];
52 } else if (METRIC == L1) {
53 Vec3<T> d = (p1 - p2).abs();
54 *dist = (d[0] + d[1] + d[2]);
59 result = *dist <= threshold;
63 /// Kernel for CountHashTableEntries
65 __global__ void CountHashTableEntriesKernel(uint32_t* count_table,
66 size_t hash_table_size,
68 const T* const __restrict__ points,
70 const int idx = blockDim.x * blockIdx.x + threadIdx.x;
71 if (idx >= num_points) return;
73 Vec3<T> pos(&points[idx * 3]);
75 Vec3<int> voxel_index = ComputeVoxelIndex(pos, inv_voxel_size);
76 size_t hash = SpatialHash(voxel_index) % hash_table_size;
77 atomicAdd(&count_table[hash + 1], 1);
80 /// Counts for each hash entry the number of points that map to this entry.
82 /// \param count_table Pointer to the table for counting.
83 /// The first element will not be used, i.e. the
84 /// number of points for the first hash entry is in count_table[1].
85 /// This array must be initialized before calling this function.
87 /// \param count_table_size This is the size of the hash table + 1.
89 /// \param inv_voxel_size Reciprocal of the voxel size
91 /// \param points Array with the 3D point positions.
93 /// \param num_points The number of points.
96 void CountHashTableEntries(const cudaStream_t& stream,
97 uint32_t* count_table,
98 size_t count_table_size,
102 const int BLOCKSIZE = 64;
103 dim3 block(BLOCKSIZE, 1, 1);
105 grid.x = utility::DivUp(num_points, block.x);
108 CountHashTableEntriesKernel<T><<<grid, block, 0, stream>>>(
109 count_table, count_table_size - 1, inv_voxel_size, points,
113 /// Kernel for ComputePointIndexTable
115 __global__ void ComputePointIndexTableKernel(
116 uint32_t* __restrict__ point_index_table,
117 uint32_t* __restrict__ count_tmp,
118 const uint32_t* const __restrict__ hash_table_cell_splits,
119 size_t hash_table_size,
121 const T* const __restrict__ points,
122 const size_t points_start_idx,
123 const size_t points_end_idx) {
124 const int idx = blockDim.x * blockIdx.x + threadIdx.x + points_start_idx;
125 if (idx >= points_end_idx) return;
127 Vec3<T> pos(&points[idx * 3]);
129 Vec3<int> voxel_index = ComputeVoxelIndex(pos, inv_voxel_size);
130 size_t hash = SpatialHash(voxel_index[0], voxel_index[1], voxel_index[2]) %
133 point_index_table[hash_table_cell_splits[hash] +
134 atomicAdd(&count_tmp[hash], 1)] = idx;
137 /// Writes the index of the points to the hash cells.
139 /// \param point_index_table The output array storing the point indices for
140 /// all cells. Start and end of each cell is defined by
141 /// \p hash_table_prefix_sum
143 /// \param count_tmp Temporary memory of size
144 /// \p hash_table_cell_splits_size.
146 /// \param hash_table_cell_splits The row splits array describing the start
147 /// and end of each cell.
149 /// \param hash_table_cell_splits_size The size of the hash table.
151 /// \param inv_voxel_size Reciprocal of the voxel size
153 /// \param points Array with the 3D point positions.
155 /// \param num_points The number of points.
158 void ComputePointIndexTable(
159 const cudaStream_t& stream,
160 uint32_t* __restrict__ point_index_table,
161 uint32_t* __restrict__ count_tmp,
162 const uint32_t* const __restrict__ hash_table_cell_splits,
163 size_t hash_table_cell_splits_size,
165 const T* const __restrict__ points,
166 size_t points_start_idx,
167 size_t points_end_idx) {
168 cudaMemsetAsync(count_tmp, 0,
169 sizeof(uint32_t) * hash_table_cell_splits_size, stream);
170 size_t num_points = points_end_idx - points_start_idx;
172 const int BLOCKSIZE = 64;
173 dim3 block(BLOCKSIZE, 1, 1);
175 grid.x = utility::DivUp(num_points, block.x);
178 ComputePointIndexTableKernel<T><<<grid, block, 0, stream>>>(
179 point_index_table, count_tmp, hash_table_cell_splits,
180 hash_table_cell_splits_size - 1, inv_voxel_size, points,
181 points_start_idx, points_end_idx);
184 /// Kernel for CountNeighbors
185 template <int METRIC, bool IGNORE_QUERY_POINT, class T>
186 __global__ void CountNeighborsKernel(
187 uint32_t* __restrict__ neighbors_count,
188 const uint32_t* const __restrict__ point_index_table,
189 const uint32_t* const __restrict__ hash_table_cell_splits,
190 size_t hash_table_size,
191 const T* const __restrict__ query_points,
193 const T* const __restrict__ points,
194 const T inv_voxel_size,
197 int query_idx = blockDim.x * blockIdx.x + threadIdx.x;
198 if (query_idx >= num_queries) return;
200 int count = 0; // counts the number of neighbors for this query point
202 Vec3<T> query_pos(query_points[query_idx * 3 + 0],
203 query_points[query_idx * 3 + 1],
204 query_points[query_idx * 3 + 2]);
205 Vec3<int> voxel_index = ComputeVoxelIndex(query_pos, inv_voxel_size);
206 int hash = SpatialHash(voxel_index[0], voxel_index[1], voxel_index[2]) %
209 int bins_to_visit[8] = {hash, -1, -1, -1, -1, -1, -1, -1};
211 for (int dz = -1; dz <= 1; dz += 2)
212 for (int dy = -1; dy <= 1; dy += 2)
213 for (int dx = -1; dx <= 1; dx += 2) {
214 Vec3<T> p = query_pos + radius * Vec3<T>(T(dx), T(dy), T(dz));
215 voxel_index = ComputeVoxelIndex(p, inv_voxel_size);
216 hash = SpatialHash(voxel_index[0], voxel_index[1],
220 // insert without duplicates
221 for (int i = 0; i < 8; ++i) {
222 if (bins_to_visit[i] == hash) {
224 } else if (bins_to_visit[i] == -1) {
225 bins_to_visit[i] = hash;
231 for (int bin_i = 0; bin_i < 8; ++bin_i) {
232 int bin = bins_to_visit[bin_i];
233 if (bin == -1) break;
235 size_t begin_idx = hash_table_cell_splits[bin];
236 size_t end_idx = hash_table_cell_splits[bin + 1];
238 for (size_t j = begin_idx; j < end_idx; ++j) {
239 uint32_t idx = point_index_table[j];
241 Vec3<T> p(&points[idx * 3 + 0]);
242 if (IGNORE_QUERY_POINT) {
243 if ((query_pos == p).all()) continue;
247 if (NeighborTest<METRIC>(p, query_pos, &dist, threshold)) ++count;
250 neighbors_count[query_idx] = count;
253 /// Count the number of neighbors for each query point
255 /// \param neighbors_count Output array for counting the number of neighbors.
256 /// The size of the array is \p num_queries.
258 /// \param point_index_table The array storing the point indices for all
259 /// cells. Start and end of each cell is defined by \p
260 /// hash_table_cell_splits
262 /// \param hash_table_cell_splits The row splits array describing the start
263 /// and end of each cell.
265 /// \param hash_table_cell_splits_size This is the length of the
266 /// hash_table_cell_splits array.
268 /// \param query_points Array with the 3D query positions. This may be the
269 /// same array as \p points.
271 /// \param num_queries The number of query points.
273 /// \param points Array with the 3D point positions.
275 /// \param num_points The number of points.
277 /// \param inv_voxel_size Reciprocal of the voxel size
279 /// \param radius The search radius.
281 /// \param metric One of L1, L2, Linf. Defines the distance metric for the
284 /// \param ignore_query_point If true then points with the same position as
285 /// the query point will be ignored.
288 void CountNeighbors(const cudaStream_t& stream,
289 uint32_t* neighbors_count,
290 const uint32_t* const point_index_table,
291 const uint32_t* const hash_table_cell_splits,
292 size_t hash_table_cell_splits_size,
293 const T* const query_points,
295 const T* const points,
296 const T inv_voxel_size,
299 const bool ignore_query_point) {
300 const T threshold = (metric == L2 ? radius * radius : radius);
302 const int BLOCKSIZE = 64;
303 dim3 block(BLOCKSIZE, 1, 1);
305 grid.x = utility::DivUp(num_queries, block.x);
308 #define FN_PARAMETERS \
309 neighbors_count, point_index_table, hash_table_cell_splits, \
310 hash_table_cell_splits_size - 1, query_points, num_queries, \
311 points, inv_voxel_size, radius, threshold
313 #define CALL_TEMPLATE(METRIC) \
314 if (METRIC == metric) { \
315 if (ignore_query_point) \
316 CountNeighborsKernel<METRIC, true, T> \
317 <<<grid, block, 0, stream>>>(FN_PARAMETERS); \
319 CountNeighborsKernel<METRIC, false, T> \
320 <<<grid, block, 0, stream>>>(FN_PARAMETERS); \
332 /// Kernel for WriteNeighborsIndicesAndDistances
336 bool IGNORE_QUERY_POINT,
337 bool RETURN_DISTANCES>
338 __global__ void WriteNeighborsIndicesAndDistancesKernel(
339 TIndex* __restrict__ indices,
340 T* __restrict__ distances,
341 const int64_t* const __restrict__ neighbors_row_splits,
342 const uint32_t* const __restrict__ point_index_table,
343 const uint32_t* const __restrict__ hash_table_cell_splits,
344 size_t hash_table_size,
345 const T* const __restrict__ query_points,
347 const T* const __restrict__ points,
348 const T inv_voxel_size,
351 int query_idx = blockDim.x * blockIdx.x + threadIdx.x;
352 if (query_idx >= num_queries) return;
354 int count = 0; // counts the number of neighbors for this query point
356 size_t indices_offset = neighbors_row_splits[query_idx];
358 Vec3<T> query_pos(query_points[query_idx * 3 + 0],
359 query_points[query_idx * 3 + 1],
360 query_points[query_idx * 3 + 2]);
361 Vec3<int> voxel_index = ComputeVoxelIndex(query_pos, inv_voxel_size);
362 int hash = SpatialHash(voxel_index) % hash_table_size;
364 int bins_to_visit[8] = {hash, -1, -1, -1, -1, -1, -1, -1};
366 for (int dz = -1; dz <= 1; dz += 2) {
367 for (int dy = -1; dy <= 1; dy += 2) {
368 for (int dx = -1; dx <= 1; dx += 2) {
369 Vec3<T> p = query_pos + radius * Vec3<T>(T(dx), T(dy), T(dz));
370 voxel_index = ComputeVoxelIndex(p, inv_voxel_size);
371 hash = SpatialHash(voxel_index) % hash_table_size;
373 // insert without duplicates
374 for (int i = 0; i < 8; ++i) {
375 if (bins_to_visit[i] == hash) {
377 } else if (bins_to_visit[i] == -1) {
378 bins_to_visit[i] = hash;
386 for (int bin_i = 0; bin_i < 8; ++bin_i) {
387 int bin = bins_to_visit[bin_i];
388 if (bin == -1) break;
390 size_t begin_idx = hash_table_cell_splits[bin];
391 size_t end_idx = hash_table_cell_splits[bin + 1];
393 for (size_t j = begin_idx; j < end_idx; ++j) {
394 uint32_t idx = point_index_table[j];
396 Vec3<T> p(&points[idx * 3 + 0]);
397 if (IGNORE_QUERY_POINT) {
398 if ((query_pos == p).all()) continue;
402 if (NeighborTest<METRIC>(p, query_pos, &dist, threshold)) {
403 indices[indices_offset + count] = idx;
404 if (RETURN_DISTANCES) {
405 distances[indices_offset + count] = dist;
413 /// Write indices and distances of neighbors for each query point
415 /// \param indices Output array with the neighbors indices.
417 /// \param distances Output array with the neighbors distances. May be null
418 /// if return_distances is false.
420 /// \param neighbors_row_splits This is the prefix sum which describes
421 /// start and end of the neighbors and distances for each query point.
423 /// \param point_index_table The array storing the point indices for all
424 /// cells. Start and end of each cell is defined by \p
425 /// hash_table_cell_splits
427 /// \param hash_table_cell_splits The row splits array describing the start
428 /// and end of each cell.
430 /// \param hash_table_cell_splits_size This is the length of the
431 /// hash_table_cell_splits array.
433 /// \param query_points Array with the 3D query positions. This may be the
434 /// same array as \p points.
436 /// \param num_queries The number of query points.
438 /// \param points Array with the 3D point positions.
440 /// \param num_points The number of points.
442 /// \param inv_voxel_size Reciprocal of the voxel size
444 /// \param radius The search radius.
446 /// \param metric One of L1, L2, Linf. Defines the distance metric for the
449 /// \param ignore_query_point If true then points with the same position as
450 /// the query point will be ignored.
452 /// \param return_distances If true then this function will return the
453 /// distances for each neighbor to its query point in the same format
455 /// Note that for the L2 metric the squared distances will be returned!!
456 template <class T, class TIndex>
457 void WriteNeighborsIndicesAndDistances(
458 const cudaStream_t& stream,
461 const int64_t* const neighbors_row_splits,
462 const uint32_t* const point_index_table,
463 const uint32_t* const hash_table_cell_splits,
464 size_t hash_table_cell_splits_size,
465 const T* const query_points,
467 const T* const points,
468 const T inv_voxel_size,
471 const bool ignore_query_point,
472 const bool return_distances) {
473 const T threshold = (metric == L2 ? radius * radius : radius);
475 const int BLOCKSIZE = 64;
476 dim3 block(BLOCKSIZE, 1, 1);
478 grid.x = utility::DivUp(num_queries, block.x);
481 #define FN_PARAMETERS \
482 indices, distances, neighbors_row_splits, point_index_table, \
483 hash_table_cell_splits, hash_table_cell_splits_size, query_points, \
484 num_queries, points, inv_voxel_size, radius, threshold
486 #define CALL_TEMPLATE(METRIC, IGNORE_QUERY_POINT, RETURN_DISTANCES) \
487 if (METRIC == metric && IGNORE_QUERY_POINT == ignore_query_point && \
488 RETURN_DISTANCES == return_distances) { \
489 WriteNeighborsIndicesAndDistancesKernel< \
490 T, TIndex, METRIC, IGNORE_QUERY_POINT, RETURN_DISTANCES> \
491 <<<grid, block, 0, stream>>>(FN_PARAMETERS); \
494 #define CALL_TEMPLATE2(METRIC) \
495 CALL_TEMPLATE(METRIC, true, true) \
496 CALL_TEMPLATE(METRIC, true, false) \
497 CALL_TEMPLATE(METRIC, false, true) \
498 CALL_TEMPLATE(METRIC, false, false)
500 #define CALL_TEMPLATE3 \
508 #undef CALL_TEMPLATE2
509 #undef CALL_TEMPLATE3
514 /// Kernel for WriteNeighborsHybrid
515 template <class T, class TIndex, int METRIC, bool RETURN_DISTANCES>
516 __global__ void WriteNeighborsHybridKernel(
517 TIndex* __restrict__ indices,
518 T* __restrict__ distances,
519 TIndex* __restrict__ counts,
521 const uint32_t* const __restrict__ point_index_table,
522 const uint32_t* const __restrict__ hash_table_cell_splits,
523 size_t hash_table_size,
524 const T* const __restrict__ query_points,
526 const T* const __restrict__ points,
527 const T inv_voxel_size,
531 int query_idx = blockDim.x * blockIdx.x + threadIdx.x;
532 if (query_idx >= num_queries) return;
534 int count = 0; // counts the number of neighbors for this query point
536 size_t indices_offset = query_offset + max_knn * query_idx;
538 Vec3<T> query_pos(query_points[query_idx * 3 + 0],
539 query_points[query_idx * 3 + 1],
540 query_points[query_idx * 3 + 2]);
541 Vec3<int> voxel_index = ComputeVoxelIndex(query_pos, inv_voxel_size);
542 int hash = SpatialHash(voxel_index) % hash_table_size;
544 int bins_to_visit[8] = {hash, -1, -1, -1, -1, -1, -1, -1};
546 for (int dz = -1; dz <= 1; dz += 2) {
547 for (int dy = -1; dy <= 1; dy += 2) {
548 for (int dx = -1; dx <= 1; dx += 2) {
549 Vec3<T> p = query_pos + radius * Vec3<T>(T(dx), T(dy), T(dz));
550 voxel_index = ComputeVoxelIndex(p, inv_voxel_size);
551 hash = SpatialHash(voxel_index) % hash_table_size;
553 // insert without duplicates
554 for (int i = 0; i < 8; ++i) {
555 if (bins_to_visit[i] == hash) {
557 } else if (bins_to_visit[i] == -1) {
558 bins_to_visit[i] = hash;
569 for (int bin_i = 0; bin_i < 8; ++bin_i) {
570 int bin = bins_to_visit[bin_i];
571 if (bin == -1) break;
573 size_t begin_idx = hash_table_cell_splits[bin];
574 size_t end_idx = hash_table_cell_splits[bin + 1];
576 for (size_t j = begin_idx; j < end_idx; ++j) {
577 uint32_t idx = point_index_table[j];
579 Vec3<T> p(&points[idx * 3 + 0]);
582 if (NeighborTest<METRIC>(p, query_pos, &dist, threshold)) {
583 // If count is less than max_knn, record idx and dist.
584 if (count < max_knn) {
585 indices[indices_offset + count] = idx;
586 distances[indices_offset + count] = dist;
587 // Update max_index and max_value.
588 if (count == 0 || max_value < dist) {
595 // If dist is smaller than current max_value.
596 if (max_value > dist) {
597 // Replace idx and dist at current max_index.
598 indices[indices_offset + max_index] = idx;
599 distances[indices_offset + max_index] = dist;
603 for (auto k = 0; k < max_knn; ++k) {
604 if (distances[indices_offset + k] > max_value) {
606 max_value = distances[indices_offset + k];
615 counts[query_idx] = count;
618 for (int i = 0; i < count - 1; ++i) {
619 for (int j = 0; j < count - i - 1; ++j) {
620 if (distances[indices_offset + j] >
621 distances[indices_offset + j + 1]) {
622 T dist_tmp = distances[indices_offset + j];
623 TIndex ind_tmp = indices[indices_offset + j];
624 distances[indices_offset + j] =
625 distances[indices_offset + j + 1];
626 indices[indices_offset + j] = indices[indices_offset + j + 1];
627 distances[indices_offset + j + 1] = dist_tmp;
628 indices[indices_offset + j + 1] = ind_tmp;
634 /// Write indices and distances for each query point in hybrid search mode.
636 /// \param indices Output array with the neighbors indices.
638 /// \param distances Output array with the neighbors distances. May be null
639 /// if return_distances is false.
641 /// \param counts Output array with the neighbor counts.
643 /// \param point_index_table The array storing the point indices for all
644 /// cells. Start and end of each cell is defined by
645 /// \p hash_table_cell_splits
647 /// \param hash_table_cell_splits The row splits array describing the start
648 /// and end of each cell.
650 /// \param hash_table_cell_splits_size This is the length of the
651 /// hash_table_cell_splits array.
653 /// \param query_points Array with the 3D query positions. This may be the
654 /// same array as \p points.
656 /// \param num_queries The number of query points.
658 /// \param points Array with the 3D point positions.
660 /// \param num_points The number of points.
662 /// \param inv_voxel_size Reciprocal of the voxel size
664 /// \param radius The search radius.
666 /// \param metric One of L1, L2, Linf. Defines the distance metric for the
669 /// \param ignore_query_point If true then points with the same position as
670 /// the query point will be ignored.
672 /// \param return_distances If true then this function will return the
673 /// distances for each neighbor to its query point in the same format
675 /// Note that for the L2 metric the squared distances will be returned!!
676 template <class T, class TIndex>
677 void WriteNeighborsHybrid(const cudaStream_t& stream,
682 const uint32_t* const point_index_table,
683 const uint32_t* const hash_table_cell_splits,
684 size_t hash_table_cell_splits_size,
685 const T* const query_points,
687 const T* const points,
688 const T inv_voxel_size,
692 const bool return_distances) {
693 const T threshold = (metric == L2 ? radius * radius : radius);
695 const int BLOCKSIZE = 64;
696 dim3 block(BLOCKSIZE, 1, 1);
698 grid.x = utility::DivUp(num_queries, block.x);
701 #define FN_PARAMETERS \
702 indices, distances, counts, query_offset, point_index_table, \
703 hash_table_cell_splits, hash_table_cell_splits_size - 1, \
704 query_points, num_queries, points, inv_voxel_size, radius, \
707 #define CALL_TEMPLATE(METRIC, RETURN_DISTANCES) \
708 if (METRIC == metric && RETURN_DISTANCES == return_distances) { \
709 WriteNeighborsHybridKernel<T, TIndex, METRIC, RETURN_DISTANCES> \
710 <<<grid, block, 0, stream>>>(FN_PARAMETERS); \
713 #define CALL_TEMPLATE2(METRIC) \
714 CALL_TEMPLATE(METRIC, true) \
715 CALL_TEMPLATE(METRIC, false)
717 #define CALL_TEMPLATE3 \
725 #undef CALL_TEMPLATE2
726 #undef CALL_TEMPLATE3
734 void BuildSpatialHashTableCUDA(const cudaStream_t& stream,
737 int texture_alignment,
738 const size_t num_points,
739 const T* const points,
741 const size_t points_row_splits_size,
742 const int64_t* points_row_splits,
743 const uint32_t* hash_table_splits,
744 const size_t hash_table_cell_splits_size,
745 uint32_t* hash_table_cell_splits,
746 uint32_t* hash_table_index) {
747 const bool get_temp_size = !temp;
750 temp = (char*)1; // worst case pointer alignment
751 temp_size = std::numeric_limits<int64_t>::max();
754 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
756 std::pair<uint32_t*, size_t> count_tmp =
757 mem_temp.Alloc<uint32_t>(hash_table_cell_splits_size);
759 const int batch_size = points_row_splits_size - 1;
760 const T voxel_size = 2 * radius;
761 const T inv_voxel_size = 1 / voxel_size;
763 // count number of points per hash entry
764 if (!get_temp_size) {
765 cudaMemsetAsync(count_tmp.first, 0, sizeof(uint32_t) * count_tmp.second,
768 for (int i = 0; i < batch_size; ++i) {
769 const size_t hash_table_size =
770 hash_table_splits[i + 1] - hash_table_splits[i];
771 const size_t first_cell_idx = hash_table_splits[i];
772 const size_t num_points_i =
773 points_row_splits[i + 1] - points_row_splits[i];
774 const T* const points_i = points + 3 * points_row_splits[i];
776 CountHashTableEntries(stream, count_tmp.first + first_cell_idx,
777 hash_table_size + 1, inv_voxel_size, points_i,
782 // compute prefix sum of the hash entry counts and store in
783 // hash_table_cell_splits
785 std::pair<void*, size_t> inclusive_scan_temp(nullptr, 0);
786 cub::DeviceScan::InclusiveSum(inclusive_scan_temp.first,
787 inclusive_scan_temp.second,
788 count_tmp.first, hash_table_cell_splits,
789 count_tmp.second, stream);
791 inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second);
793 if (!get_temp_size) {
794 cub::DeviceScan::InclusiveSum(
795 inclusive_scan_temp.first, inclusive_scan_temp.second,
796 count_tmp.first, hash_table_cell_splits, count_tmp.second,
800 mem_temp.Free(inclusive_scan_temp);
803 // now compute the global indices which allows us to lookup the point index
804 // for the entries in the hash cell
805 if (!get_temp_size) {
806 for (int i = 0; i < batch_size; ++i) {
807 const size_t hash_table_size =
808 hash_table_splits[i + 1] - hash_table_splits[i];
809 const size_t first_cell_idx = hash_table_splits[i];
810 const size_t points_start_idx = points_row_splits[i];
811 const size_t points_end_idx = points_row_splits[i + 1];
812 ComputePointIndexTable(stream, hash_table_index, count_tmp.first,
813 hash_table_cell_splits + first_cell_idx,
814 hash_table_size + 1, inv_voxel_size, points,
815 points_start_idx, points_end_idx);
819 mem_temp.Free(count_tmp);
822 // return the memory peak as the required temporary memory size.
823 temp_size = mem_temp.MaxUsed();
828 template <class T, class TIndex>
829 void SortPairs(void* temp,
831 int texture_alignment,
833 int64_t num_segments,
834 const int64_t* query_neighbors_row_splits,
835 TIndex* indices_unsorted,
836 T* distances_unsorted,
837 TIndex* indices_sorted,
838 T* distances_sorted) {
839 const bool get_temp_size = !temp;
842 temp = (char*)1; // worst case pointer alignment
843 temp_size = std::numeric_limits<int64_t>::max();
846 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
848 std::pair<void*, size_t> sort_temp(nullptr, 0);
850 cub::DeviceSegmentedRadixSort::SortPairs(
851 sort_temp.first, sort_temp.second, distances_unsorted,
852 distances_sorted, indices_unsorted, indices_sorted, num_indices,
853 num_segments, query_neighbors_row_splits,
854 query_neighbors_row_splits + 1);
855 sort_temp = mem_temp.Alloc(sort_temp.second);
857 if (!get_temp_size) {
858 cub::DeviceSegmentedRadixSort::SortPairs(
859 sort_temp.first, sort_temp.second, distances_unsorted,
860 distances_sorted, indices_unsorted, indices_sorted, num_indices,
861 num_segments, query_neighbors_row_splits,
862 query_neighbors_row_splits + 1);
864 mem_temp.Free(sort_temp);
867 // return the memory peak as the required temporary memory size.
868 temp_size = mem_temp.MaxUsed();
873 template <class T, class TIndex, class OUTPUT_ALLOCATOR>
874 void FixedRadiusSearchCUDA(const cudaStream_t& stream,
877 int texture_alignment,
878 int64_t* query_neighbors_row_splits,
880 const T* const points,
882 const T* const queries,
884 const size_t points_row_splits_size,
885 const int64_t* const points_row_splits,
886 const size_t queries_row_splits_size,
887 const int64_t* const queries_row_splits,
888 const uint32_t* const hash_table_splits,
889 size_t hash_table_cell_splits_size,
890 const uint32_t* const hash_table_cell_splits,
891 const uint32_t* const hash_table_index,
893 const bool ignore_query_point,
894 const bool return_distances,
895 OUTPUT_ALLOCATOR& output_allocator) {
896 const bool get_temp_size = !temp;
899 temp = (char*)1; // worst case pointer alignment
900 temp_size = std::numeric_limits<int64_t>::max();
903 // return empty output arrays if there are no points
904 if ((0 == num_points || 0 == num_queries) && !get_temp_size) {
905 cudaMemsetAsync(query_neighbors_row_splits, 0,
906 sizeof(int64_t) * (num_queries + 1), stream);
908 output_allocator.AllocIndices(&indices_ptr, 0);
911 output_allocator.AllocDistances(&distances_ptr, 0);
916 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
918 const int batch_size = points_row_splits_size - 1;
919 const T voxel_size = 2 * radius;
920 const T inv_voxel_size = 1 / voxel_size;
922 std::pair<uint32_t*, size_t> query_neighbors_count =
923 mem_temp.Alloc<uint32_t>(num_queries);
925 // we need this value to compute the size of the index array
926 if (!get_temp_size) {
927 for (int i = 0; i < batch_size; ++i) {
928 const size_t hash_table_size =
929 hash_table_splits[i + 1] - hash_table_splits[i];
930 const size_t first_cell_idx = hash_table_splits[i];
931 const size_t queries_start_idx = queries_row_splits[i];
932 const T* const queries_i = queries + 3 * queries_row_splits[i];
933 const size_t num_queries_i =
934 queries_row_splits[i + 1] - queries_row_splits[i];
937 stream, query_neighbors_count.first + queries_start_idx,
938 hash_table_index, hash_table_cell_splits + first_cell_idx,
939 hash_table_size + 1, queries_i, num_queries_i, points,
940 inv_voxel_size, radius, metric, ignore_query_point);
944 int64_t last_prefix_sum_entry = 0;
946 std::pair<void*, size_t> inclusive_scan_temp(nullptr, 0);
947 cub::DeviceScan::InclusiveSum(
948 inclusive_scan_temp.first, inclusive_scan_temp.second,
949 query_neighbors_count.first, query_neighbors_row_splits + 1,
950 num_queries, stream);
952 inclusive_scan_temp = mem_temp.Alloc(inclusive_scan_temp.second);
954 if (!get_temp_size) {
955 // set first element to zero
956 cudaMemsetAsync(query_neighbors_row_splits, 0, sizeof(int64_t),
958 cub::DeviceScan::InclusiveSum(
959 inclusive_scan_temp.first, inclusive_scan_temp.second,
960 query_neighbors_count.first, query_neighbors_row_splits + 1,
961 num_queries, stream);
963 // get the last value
964 cudaMemcpyAsync(&last_prefix_sum_entry,
965 query_neighbors_row_splits + num_queries,
966 sizeof(int64_t), cudaMemcpyDeviceToHost, stream);
967 // wait for the async copies
968 while (cudaErrorNotReady == cudaStreamQuery(stream)) { /*empty*/
971 mem_temp.Free(inclusive_scan_temp);
973 mem_temp.Free(query_neighbors_count);
976 // return the memory peak as the required temporary memory size.
977 temp_size = mem_temp.MaxUsed();
981 if (!get_temp_size) {
982 // allocate the output array for the neighbor indices
983 const size_t num_indices = last_prefix_sum_entry;
988 output_allocator.AllocIndices(&indices_ptr, num_indices);
989 output_allocator.AllocDistances(&distances_ptr, num_indices);
990 for (int i = 0; i < batch_size; ++i) {
991 const size_t hash_table_size =
992 hash_table_splits[i + 1] - hash_table_splits[i];
993 const size_t first_cell_idx = hash_table_splits[i];
994 const T* const queries_i = queries + 3 * queries_row_splits[i];
995 const size_t num_queries_i =
996 queries_row_splits[i + 1] - queries_row_splits[i];
998 WriteNeighborsIndicesAndDistances(
999 stream, indices_ptr, distances_ptr,
1000 query_neighbors_row_splits + queries_row_splits[i],
1001 hash_table_index, hash_table_cell_splits + first_cell_idx,
1002 hash_table_size, queries_i, num_queries_i, points,
1003 inv_voxel_size, radius, metric, ignore_query_point,
1010 template <class T, class TIndex, class OUTPUT_ALLOCATOR>
1011 void HybridSearchCUDA(const cudaStream_t stream,
1013 const T* const points,
1015 const T* const queries,
1018 const size_t points_row_splits_size,
1019 const int64_t* const points_row_splits,
1020 const size_t queries_row_splits_size,
1021 const int64_t* const queries_row_splits,
1022 const uint32_t* const hash_table_splits,
1023 size_t hash_table_cell_splits_size,
1024 const uint32_t* const hash_table_cell_splits,
1025 const uint32_t* const hash_table_index,
1026 const Metric metric,
1027 OUTPUT_ALLOCATOR& output_allocator) {
1028 // return empty output arrays if there are no points
1029 if (0 == num_points || 0 == num_queries) {
1030 TIndex* indices_ptr;
1031 output_allocator.AllocIndices(&indices_ptr, 0);
1034 output_allocator.AllocDistances(&distances_ptr, 0);
1037 output_allocator.AllocCounts(&counts_ptr, 0);
1041 const int batch_size = points_row_splits_size - 1;
1042 const T voxel_size = 2 * radius;
1043 const T inv_voxel_size = 1 / voxel_size;
1045 // Allocate output pointers.
1046 const size_t num_indices = num_queries * max_knn;
1048 TIndex* indices_ptr;
1049 output_allocator.AllocIndices(&indices_ptr, num_indices, -1);
1052 output_allocator.AllocDistances(&distances_ptr, num_indices, 0);
1055 output_allocator.AllocCounts(&counts_ptr, num_queries, 0);
1057 for (int i = 0; i < batch_size; ++i) {
1058 const size_t hash_table_size =
1059 hash_table_splits[i + 1] - hash_table_splits[i];
1060 const size_t query_offset = max_knn * queries_row_splits[i];
1061 const size_t first_cell_idx = hash_table_splits[i];
1062 const T* const queries_i = queries + 3 * queries_row_splits[i];
1063 const size_t num_queries_i =
1064 queries_row_splits[i + 1] - queries_row_splits[i];
1066 WriteNeighborsHybrid(
1067 stream, indices_ptr, distances_ptr, counts_ptr, query_offset,
1068 hash_table_index, hash_table_cell_splits + first_cell_idx,
1069 hash_table_size + 1, queries_i, num_queries_i, points,
1070 inv_voxel_size, radius, max_knn, metric, true);
1078 } // namespace cloudViewer