ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
contrib_nns.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include <Parallel.h>
9 
10 #include <numeric>
11 
13 
14 namespace cloudViewer {
15 namespace ml {
16 namespace contrib {
17 
18 const core::Tensor KnnSearch(const core::Tensor& query_points,
19  const core::Tensor& dataset_points,
20  int knn) {
21  // Check dtype.
22  if (dataset_points.GetDtype() != query_points.GetDtype()) {
23  utility::LogError("Point dtype mismatch {} != {}.",
24  dataset_points.GetDtype().ToString(),
25  query_points.GetDtype().ToString());
26  }
27 
28  // Check shape.
29  if (query_points.NumDims() != 2) {
30  utility::LogError("query_points must be of shape {n_query_points, d}.");
31  }
32  if (dataset_points.NumDims() != 2) {
34  "dataset_points must be of shape {n_dataset_points, d}.");
35  }
36  if (query_points.GetShape()[1] != dataset_points.GetShape()[1]) {
37  utility::LogError("Point dimensions mismatch {} != {}.",
38  query_points.GetShape()[1],
39  dataset_points.GetShape()[1]);
40  }
41 
42  // Call NNS.
43  core::nns::NearestNeighborSearch nns(dataset_points);
44  nns.KnnIndex();
45  core::Tensor indices;
46  core::Tensor distances;
47  std::tie(indices, distances) = nns.KnnSearch(query_points, knn);
48  return indices.To(core::Int32);
49 }
50 
51 const core::Tensor RadiusSearch(const core::Tensor& query_points,
52  const core::Tensor& dataset_points,
53  const core::Tensor& query_batches,
54  const core::Tensor& dataset_batches,
55  double radius) {
56  // Check dtype.
57  if (dataset_points.GetDtype() != query_points.GetDtype()) {
58  utility::LogError("Point dtype mismatch {} != {}.",
59  dataset_points.GetDtype().ToString(),
60  query_points.GetDtype().ToString());
61  }
62  if (query_batches.GetDtype() != core::Int32) {
63  utility::LogError("query_batches must be of dtype Int32.");
64  }
65  if (dataset_batches.GetDtype() != core::Int32) {
66  utility::LogError("dataset_batches must be of dtype Int32.");
67  }
68 
69  // Check shapes.
70  if (query_points.NumDims() != 2) {
71  utility::LogError("query_points must be of shape {n_query_points, d}.");
72  }
73  if (dataset_points.NumDims() != 2) {
75  "dataset_points must be of shape {n_dataset_points, d}.");
76  }
77  if (query_points.GetShape()[1] != dataset_points.GetShape()[1]) {
78  utility::LogError("Point dimensions mismatch {} != {}.",
79  query_points.GetShape()[1],
80  dataset_points.GetShape()[1]);
81  }
82  if (query_batches.NumDims() != 1) {
83  utility::LogError("query_batches must be of shape {n_batches,}.");
84  }
85  if (dataset_batches.NumDims() != 1) {
86  utility::LogError("dataset_batches must be of shape {n_batches,}.");
87  }
88  if (query_batches.GetShape()[0] != dataset_batches.GetShape()[0]) {
89  utility::LogError("Number of batches lengths not the same: {} != {}.",
90  query_batches.GetShape()[0],
91  dataset_batches.GetShape()[0]);
92  }
93  int64_t num_batches = query_batches.GetShape()[0];
94 
95  // Check consistentency of batch sizes with total number of points.
96  if (query_batches.Sum({0}).Item<int32_t>() != query_points.GetShape()[0]) {
98  "query_batches is not consistent with query_points: {} != {}.",
99  query_batches.Sum({0}).Item<int32_t>(),
100  query_points.GetShape()[0]);
101  }
102  if (dataset_batches.Sum({0}).Item<int32_t>() !=
103  dataset_points.GetShape()[0]) {
105  "dataset_batches is not consistent with dataset_points: {} != "
106  "{}.",
107  dataset_batches.Sum({0}).Item<int32_t>(),
108  dataset_points.GetShape()[0]);
109  }
110  int64_t num_query_points = query_points.GetShape()[0];
111 
112  // Call radius search for each batch.
113  std::vector<core::Tensor> batched_indices(num_batches);
114  std::vector<core::Tensor> batched_num_neighbors(num_batches);
115 
116  // Calculate prefix-sum.
117  std::vector<int32_t> query_prefix_indices(num_batches + 1, 0);
118  std::vector<int32_t> dataset_prefix_indices(num_batches + 1, 0);
119 
120  const int32_t* query_batch_flat =
121  static_cast<const int32_t*>(query_batches.GetDataPtr());
122  const int32_t* dataset_batch_flat =
123  static_cast<const int32_t*>(dataset_batches.GetDataPtr());
124 
125  // TODO: implement Cumsum function in Tensor.
126  std::partial_sum(query_batch_flat, query_batch_flat + num_batches,
127  query_prefix_indices.data() + 1);
128  std::partial_sum(dataset_batch_flat, dataset_batch_flat + num_batches,
129  dataset_prefix_indices.data() + 1);
130 
131  // Parallelization is applied point-wise in NanoFlannIndex.
132  for (int64_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
133  core::Tensor current_query_points =
134  query_points.Slice(0, query_prefix_indices[batch_idx],
135  query_prefix_indices[batch_idx + 1]);
136 
137  core::Tensor current_dataset_points =
138  dataset_points.Slice(0, dataset_prefix_indices[batch_idx],
139  dataset_prefix_indices[batch_idx + 1]);
140 
141  // Call radius search.
142  core::nns::NearestNeighborSearch nns(current_dataset_points);
143  nns.FixedRadiusIndex();
144  core::Tensor indices;
145  core::Tensor distances;
146  core::Tensor neighbors_row_splits;
147  std::tie(indices, distances, neighbors_row_splits) =
148  nns.FixedRadiusSearch(current_query_points, radius);
149  batched_indices[batch_idx] = indices;
150  int64_t current_num_query_points = current_query_points.GetShape()[0];
151  core::Tensor num_neighbors =
152  neighbors_row_splits.Slice(0, 1, current_num_query_points + 1)
153  .Sub(neighbors_row_splits.Slice(
154  0, 0, current_num_query_points))
155  .To(core::Int64);
156  batched_num_neighbors[batch_idx] = num_neighbors;
157  }
158 
159  // Find global maximum number of neighbors.
160  int64_t max_num_neighbors = 0;
161  for (const auto& num_neighbors : batched_num_neighbors) {
162  max_num_neighbors = std::max(num_neighbors.Max({0}).Item<int64_t>(),
163  max_num_neighbors);
164  }
165 
166  // Convert to the required output format. Pad with -1.
168  {num_query_points, max_num_neighbors}, -1, core::Int64);
169 
170 #pragma omp parallel for schedule(static) \
171  num_threads(utility::EstimateMaxThreads())
172  for (int64_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
173  int32_t result_start_idx = query_prefix_indices[batch_idx];
174  int32_t result_end_idx = query_prefix_indices[batch_idx + 1];
175 
176  core::Tensor indices = batched_indices[batch_idx].Add(
177  dataset_prefix_indices[batch_idx]);
178  core::Tensor num_neighbors = batched_num_neighbors[batch_idx];
179 
180  // Sanity check.
181  int64_t batch_size =
182  result_end_idx - result_start_idx; // Exclusive result_end_idx.
183  if (num_neighbors.GetShape()[0] != batch_size) {
185  "Sanity check failed, batch_id {}: {} != batchsize {}.",
186  batch_idx, num_neighbors.GetShape()[0], batch_size);
187  }
188 
189  int64_t indices_start_idx = 0;
190  for (int64_t i = 0; i < batch_size; i++) {
191  int64_t num_neighbor = num_neighbors[i].Item<int64_t>();
192  core::Tensor result_slice = result.Slice(0, result_start_idx + i,
193  result_start_idx + i + 1)
194  .Slice(1, 0, num_neighbor);
195  core::Tensor indices_slice = indices.Slice(
196  0, indices_start_idx, indices_start_idx + num_neighbor);
197  result_slice.AsRvalue() = indices_slice.View({1, num_neighbor});
198  indices_start_idx += num_neighbor;
199  }
200  }
201 
202  return result.To(core::Int32);
203 }
204 } // namespace contrib
205 } // namespace ml
206 } // namespace cloudViewer
core::Tensor result
Definition: VtkUtils.cpp:76
std::string ToString() const
Definition: Dtype.h:65
int64_t NumDims() const
Definition: Tensor.h:1172
Tensor Sum(const SizeVector &dims, bool keepdim=false) const
Definition: Tensor.cpp:1240
Dtype GetDtype() const
Definition: Tensor.h:1164
Tensor Sub(const Tensor &value) const
Subtracts a tensor and returns the resulting tensor.
Definition: Tensor.cpp:1133
Tensor Add(const Tensor &value) const
Adds a tensor and returns the resulting tensor.
Definition: Tensor.cpp:1097
Tensor View(const SizeVector &dst_shape) const
Definition: Tensor.cpp:721
SizeVector GetShape() const
Definition: Tensor.h:1127
Tensor Slice(int64_t dim, int64_t start, int64_t stop, int64_t step=1) const
Definition: Tensor.cpp:857
Tensor To(Dtype dtype, bool copy=false) const
Definition: Tensor.cpp:739
static Tensor Full(const SizeVector &shape, T fill_value, Dtype dtype, const Device &device=Device("CPU:0"))
Create a tensor fill with specified value.
Definition: Tensor.h:253
A Class for nearest neighbor search.
std::tuple< Tensor, Tensor, Tensor > FixedRadiusSearch(const Tensor &query_points, double radius, bool sort=true)
bool FixedRadiusIndex(utility::optional< double > radius={})
std::pair< Tensor, Tensor > KnnSearch(const Tensor &query_points, int knn)
#define LogError(...)
Definition: Logging.h:60
int max(int a, int b)
Definition: cutil_math.h:48
const Dtype Int64
Definition: Dtype.cpp:47
const Dtype Int32
Definition: Dtype.cpp:46
const core::Tensor RadiusSearch(const core::Tensor &query_points, const core::Tensor &dataset_points, const core::Tensor &query_batches, const core::Tensor &dataset_batches, double radius)
Definition: contrib_nns.cpp:51
const core::Tensor KnnSearch(const core::Tensor &query_points, const core::Tensor &dataset_points, int knn)
Definition: contrib_nns.cpp:18
Generic file read and write utility for python interface.