ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
NanoFlannIndex.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
9 
10 #include <Logging.h>
11 
18 
19 namespace cloudViewer {
20 namespace core {
21 namespace nns {
22 
24 
25 NanoFlannIndex::NanoFlannIndex(const Tensor &dataset_points) {
26  SetTensorData(dataset_points);
27 };
28 
29 NanoFlannIndex::NanoFlannIndex(const Tensor &dataset_points,
30  const Dtype &index_dtype) {
31  SetTensorData(dataset_points, index_dtype);
32 };
33 
35 
// Validates `dataset_points`, caches it, and (re)builds the nanoflann KD-tree.
//
// dataset_points: must have dtype Float32 or Float64 and be a 2-D matrix of
//                 shape {n_dataset_points, d}; violations are reported as
//                 errors.
// index_dtype:    dtype used for neighbor indices; expected to be Int32 or
//                 Int64.
// Returns true once the tree has been built.
36 bool NanoFlannIndex::SetTensorData(const Tensor &dataset_points,
37  const Dtype &index_dtype) {
38  AssertTensorDtypes(dataset_points, {Float32, Float64});
// NOTE(review): assert() is compiled out when NDEBUG is defined, so in
// release builds an index_dtype other than Int32/Int64 is not rejected here.
// Consider an explicit check that reports through utility::LogError instead.
39  assert(index_dtype == Int32 || index_dtype == Int64);
40 
41  if (dataset_points.NumDims() != 2) {
43  "dataset_points must be 2D matrix, with shape "
44  "{n_dataset_points, d}.");
45  }
46 
// Keep a contiguous version of the dataset: BuildKdTree receives a raw data
// pointer, so the buffer must be densely packed. Presumably the tree keeps
// reading from this member afterwards, so it must outlive holder_ — TODO
// confirm against impl::BuildKdTree.
47  dataset_points_ = dataset_points.Contiguous();
48  index_dtype_ = index_dtype;
// Build the tree for the dataset's scalar type and the chosen index type with
// the L2 metric; the resulting holder is owned by holder_ (a unique_ptr).
50  holder_ = impl::BuildKdTree<scalar_t, int_t>(
52  dataset_points_.GetDataPtr<scalar_t>(),
53  dataset_points_.GetShape(1), /* metric */ L2);
54  });
55  return true;
56 };
57 
// k-nearest-neighbor search for every row of `query_points` (L2 metric),
// performed on the CPU by impl::KnnSearchCPU.
//
// query_points: must be on the index's device and have the index's dtype.
// knn:          requested neighbors per query; must be > 0. The effective
//               count is min(knn, GetDatasetSize()).
// Returns (indices, distances), each viewed as
// {num_query_points, min(knn, GetDatasetSize())}. Whether the distances are
// squared is decided by impl::KnnSearchCPU — confirm there.
58 std::pair<Tensor, Tensor> NanoFlannIndex::SearchKnn(const Tensor &query_points,
59  int knn) const {
60  const Dtype dtype = GetDtype();
61  const Device device = GetDevice();
62  const Dtype index_dtype = GetIndexDtype();
63 
64  core::AssertTensorDevice(query_points, device);
65  core::AssertTensorDtype(query_points, dtype);
67 
68  if (knn <= 0) {
69  utility::LogError("knn should be larger than 0.");
70  }
71 
// Never ask for more neighbors than there are dataset points.
72  const int64_t num_neighbors = std::min(
73  static_cast<int64_t>(GetDatasetSize()), static_cast<int64_t>(knn));
74  const int64_t num_query_points = query_points.GetShape(0);
75 
76  Tensor indices, distances;
// Per-query offsets written by the kernel; this method does not return them.
77  Tensor neighbors_row_splits = Tensor({num_query_points + 1}, Int64);
78  DISPATCH_FLOAT_INT_DTYPE_TO_TEMPLATE(dtype, index_dtype, [&]() {
// The kernel reads raw pointers, so pack the queries densely first.
79  const Tensor query_contiguous = query_points.Contiguous();
80  NeighborSearchAllocator<scalar_t, int_t> output_allocator(device);
81 
82  impl::KnnSearchCPU<scalar_t, int_t>(
83  holder_.get(), neighbors_row_splits.GetDataPtr<int64_t>(),
85  dataset_points_.GetDataPtr<scalar_t>(),
86  query_contiguous.GetShape(0),
87  query_contiguous.GetDataPtr<scalar_t>(),
88  query_contiguous.GetShape(1), num_neighbors, /* metric */ L2,
89  /* ignore_query_point */ false,
90  /* return_distances */ true, output_allocator);
91  indices = output_allocator.NeighborsIndex();
92  distances = output_allocator.NeighborsDistance();
// Reshape the flat kernel output into one row of num_neighbors per query.
93  indices = indices.View({num_query_points, num_neighbors});
94  distances = distances.View({num_query_points, num_neighbors});
95  });
96  return std::make_pair(indices, distances);
97 };
98 
// Fixed-radius neighbor search with one radius per query point (L2 metric),
// performed on the CPU by impl::RadiusSearchCPU.
//
// query_points: shape {num_query_points, GetDimension()}, on the index's
//               device and with the index's dtype.
// radii:        shape {num_query_points}; every entry must be > 0. Its data
//               is read as scalar_t, i.e. it is expected to share the
//               dataset dtype.
// sort:         forwarded to impl::RadiusSearchCPU (presumably orders each
//               query's neighbors by distance — confirm there).
// Returns (indices, distances, row_splits). indices and distances are flat;
// row_splits has num_query_points + 1 entries converted to index_dtype_ and
// presumably delimits query i's neighbors as [row_splits[i], row_splits[i+1]).
99 std::tuple<Tensor, Tensor, Tensor> NanoFlannIndex::SearchRadius(
100  const Tensor &query_points, const Tensor &radii, bool sort) const {
101  const Dtype dtype = GetDtype();
102  const Device device = GetDevice();
103  const Dtype index_dtype = GetIndexDtype();
104 
105  core::AssertTensorDevice(query_points, device);
107  core::AssertTensorDtype(query_points, dtype);
109 
110  // Check shapes.
111  int64_t num_query_points = query_points.GetShape(0);
112  AssertTensorShape(query_points, {utility::nullopt, GetDimension()});
113  AssertTensorShape(radii, {num_query_points});
114 
115  // Reject non-positive radii: Le(0) flags zero as well as negatives.
// NOTE(review): an element-wise <= presumably evaluates false for NaN (IEEE
// comparison), so a NaN radius is not caught by this check.
116  Tensor below_zero = radii.Le(0);
117  if (below_zero.Any().Item<bool>()) {
118  utility::LogError("radius should be larger than 0.");
119  }
120 
121  Tensor indices, distances;
// Per-query offsets written by the kernel; returned below as row_splits.
122  Tensor neighbors_row_splits = Tensor({num_query_points + 1}, Int64);
123  DISPATCH_FLOAT_INT_DTYPE_TO_TEMPLATE(dtype, index_dtype, [&]() {
// The kernel reads raw pointers, so pack the queries densely first.
124  const Tensor query_contiguous = query_points.Contiguous();
125  NeighborSearchAllocator<scalar_t, int_t> output_allocator(device);
126 
// NOTE(review): radii's raw pointer is passed without a Contiguous() call,
// unlike query_points. A strided 1-D radii view would be misread unless
// contiguity is guaranteed elsewhere.
127  impl::RadiusSearchCPU<scalar_t, int_t>(
128  holder_.get(), neighbors_row_splits.GetDataPtr<int64_t>(),
130  dataset_points_.GetDataPtr<scalar_t>(),
131  query_contiguous.GetShape(0),
132  query_contiguous.GetDataPtr<scalar_t>(),
133  query_contiguous.GetShape(1), radii.GetDataPtr<scalar_t>(),
134  /* metric */ L2,
135  /* ignore_query_point */ false, /* return_distances */ true,
136  /* normalize_distances */ false, sort, output_allocator);
137  indices = output_allocator.NeighborsIndex();
138  distances = output_allocator.NeighborsDistance();
139  });
140 
// row_splits is computed as Int64 and converted to the index dtype on return.
141  return std::make_tuple(indices, distances,
142  neighbors_row_splits.To(index_dtype_));
143 };
144 
145 std::tuple<Tensor, Tensor, Tensor> NanoFlannIndex::SearchRadius(
146  const Tensor &query_points, double radius, bool sort) const {
147  const int64_t num_query_points = query_points.GetShape()[0];
148  const Dtype dtype = GetDtype();
149  std::tuple<Tensor, Tensor, Tensor> result;
150  DISPATCH_FLOAT_DTYPE_TO_TEMPLATE(dtype, [&]() {
151  Tensor radii(std::vector<scalar_t>(num_query_points, (scalar_t)radius),
152  {num_query_points}, dtype);
153  result = SearchRadius(query_points, radii, sort);
154  });
155  return result;
156 };
157 
158 std::tuple<Tensor, Tensor, Tensor> NanoFlannIndex::SearchHybrid(
159  const Tensor &query_points, double radius, int max_knn) const {
160  const Device device = GetDevice();
161  const Dtype dtype = GetDtype();
162  const Dtype index_dtype = GetIndexDtype();
163 
164  AssertTensorDevice(query_points, device);
165  AssertTensorDtype(query_points, dtype);
166  AssertTensorShape(query_points, {utility::nullopt, GetDimension()});
167 
168  if (max_knn <= 0) {
169  utility::LogError("max_knn should be larger than 0.");
170  }
171  if (radius <= 0) {
172  utility::LogError("radius should be larger than 0.");
173  }
174 
175  int64_t num_query_points = query_points.GetShape(0);
176 
177  Tensor indices, distances, counts;
178  DISPATCH_FLOAT_INT_DTYPE_TO_TEMPLATE(dtype, index_dtype, [&]() {
179  const Tensor query_contiguous = query_points.Contiguous();
180  NeighborSearchAllocator<scalar_t, int_t> output_allocator(device);
181 
182  impl::HybridSearchCPU<scalar_t, int_t>(
183  holder_.get(), dataset_points_.GetShape(0),
184  dataset_points_.GetDataPtr<scalar_t>(),
185  query_contiguous.GetShape(0),
186  query_contiguous.GetDataPtr<scalar_t>(),
187  query_contiguous.GetShape(1), static_cast<scalar_t>(radius),
188  max_knn,
189  /* metric*/ L2, /* ignore_query_point */ false,
190  /* return_distances */ true, output_allocator);
191 
192  indices = output_allocator.NeighborsIndex().View(
193  {num_query_points, max_knn});
194  distances = output_allocator.NeighborsDistance().View(
195  {num_query_points, max_knn});
196  counts = output_allocator.NeighborsCount();
197  });
198  return std::make_tuple(indices, distances, counts);
199 }
200 
201 } // namespace nns
202 } // namespace core
203 } // namespace cloudViewer
#define DISPATCH_FLOAT_DTYPE_TO_TEMPLATE(DTYPE,...)
Definition: Dispatch.h:78
#define DISPATCH_FLOAT_INT_DTYPE_TO_TEMPLATE(FDTYPE, IDTYPE,...)
Definition: Dispatch.h:91
#define AssertTensorDevice(tensor,...)
Definition: TensorCheck.h:45
#define AssertTensorDtype(tensor,...)
Definition: TensorCheck.h:21
#define AssertTensorDtypes(tensor,...)
Definition: TensorCheck.h:33
#define AssertTensorShape(tensor,...)
Definition: TensorCheck.h:61
core::Tensor result
Definition: VtkUtils.cpp:76
Tensor Contiguous() const
Definition: Tensor.cpp:772
int64_t NumDims() const
Definition: Tensor.h:1172
Tensor View(const SizeVector &dst_shape) const
Definition: Tensor.cpp:721
SizeVector GetShape() const
Definition: Tensor.h:1127
Tensor To(Dtype dtype, bool copy=false) const
Definition: Tensor.cpp:739
std::unique_ptr< NanoFlannIndexHolderBase > holder_
std::tuple< Tensor, Tensor, Tensor > SearchHybrid(const Tensor &query_points, double radius, int max_knn) const override
std::tuple< Tensor, Tensor, Tensor > SearchRadius(const Tensor &query_points, const Tensor &radii, bool sort=true) const override
bool SetTensorData(const Tensor &dataset_points, const Dtype &index_dtype=core::Int64) override
std::pair< Tensor, Tensor > SearchKnn(const Tensor &query_points, int knn) const override
#define LogError(...)
Definition: Logging.h:60
int min(int a, int b)
Definition: cutil_math.h:53
const Dtype Int64
Definition: Dtype.cpp:47
CLOUDVIEWER_HOST_DEVICE Pair< First, Second > make_pair(const First &_first, const Second &_second)
Definition: SlabTraits.h:49
const Dtype Float64
Definition: Dtype.cpp:43
const Dtype Int32
Definition: Dtype.cpp:46
const Dtype Float32
Definition: Dtype.cpp:42
constexpr nullopt_t nullopt
Definition: Optional.h:136
Generic file read and write utility for python interface.
std::vector< PointCoordinateType > radii
Definition: qM3C2Tools.cpp:42