ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
KnnIndex.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
9 
10 #include <Logging.h>
11 
15 
16 namespace cloudViewer {
17 namespace core {
18 namespace nns {
19 
21 
22 KnnIndex::KnnIndex(const Tensor& dataset_points) {
23  SetTensorData(dataset_points);
24 }
25 
26 KnnIndex::KnnIndex(const Tensor& dataset_points, const Dtype& index_dtype) {
27  SetTensorData(dataset_points, index_dtype);
28 }
29 
31 
32 bool KnnIndex::SetTensorData(const Tensor& dataset_points,
33  const Dtype& index_dtype) {
34  int64_t num_dataset_points = dataset_points.GetShape(0);
35  Tensor points_row_splits = Tensor::Init<int64_t>({0, num_dataset_points});
36  return SetTensorData(dataset_points, points_row_splits, index_dtype);
37 }
38 
39 bool KnnIndex::SetTensorData(const Tensor& dataset_points,
40  const Tensor& points_row_splits,
41  const Dtype& index_dtype) {
42  AssertTensorDtypes(dataset_points, {Float32, Float64});
43  assert(index_dtype == Int32 || index_dtype == Int64);
44  AssertTensorDevice(points_row_splits, Device("CPU:0"));
45  AssertTensorDtype(points_row_splits, Int64);
46 
47  if (dataset_points.NumDims() != 2) {
49  "dataset_points must be 2D matrix with shape "
50  "{n_dataset_points, d}.");
51  }
52  if (dataset_points.GetShape(0) <= 0 || dataset_points.GetShape(1) <= 0) {
53  utility::LogError("Failed due to no data.");
54  }
55  if (dataset_points.GetShape(0) != points_row_splits[-1].Item<int64_t>()) {
57  "dataset_points and points_row_splits have incompatible "
58  "shapes.");
59  }
60 
61  if (dataset_points.IsCUDA()) {
62 #ifdef BUILD_CUDA_MODULE
63  dataset_points_ = dataset_points.Contiguous();
64  points_row_splits_ = points_row_splits.Contiguous();
65  index_dtype_ = index_dtype;
66  return true;
67 #else
69  "GPU Tensor is not supported when -DBUILD_CUDA_MODULE=OFF. "
70  "Please recompile Open3d With -DBUILD_CUDA_MODULE=ON.");
71 #endif
72  } else {
74  "CPU Tensor is not supported in KnnIndex. Please use "
75  "NanoFlannIndex instead.");
76  }
77  return false;
78 }
79 
80 std::pair<Tensor, Tensor> KnnIndex::SearchKnn(const Tensor& query_points,
81  int knn) const {
82  int64_t num_query_points = query_points.GetShape(0);
83  Tensor queries_row_splits = Tensor::Init<int64_t>({0, num_query_points});
84  return SearchKnn(query_points, queries_row_splits, knn);
85 }
86 
87 std::pair<Tensor, Tensor> KnnIndex::SearchKnn(const Tensor& query_points,
88  const Tensor& queries_row_splits,
89  int knn) const {
90  const Dtype dtype = GetDtype();
91  const Device device = GetDevice();
92 
93  // Only Float32, Float64 type dataset_points are supported.
94  AssertTensorDtype(query_points, dtype);
95  AssertTensorDevice(query_points, device);
97  AssertTensorDtype(queries_row_splits, Int64);
98  AssertTensorDevice(queries_row_splits, Device("CPU:0"));
99 
100  if (query_points.GetShape(0) != queries_row_splits[-1].Item<int64_t>()) {
102  "query_points and queries_row_splits have incompatible "
103  "shapes.");
104  }
105  if (knn <= 0) {
106  utility::LogError("knn should be larger than 0.");
107  }
108 
109  Tensor query_points_ = query_points.Contiguous();
110  Tensor queries_row_splits_ = queries_row_splits.Contiguous();
111 
112  Tensor neighbors_index, neighbors_distance;
113  Tensor neighbors_row_splits =
114  Tensor::Empty({query_points.GetShape(0) + 1}, Int64);
115 
116 #define KNN_PARAMETERS \
117  dataset_points_, points_row_splits_, query_points_, queries_row_splits_, \
118  knn, neighbors_index, neighbors_row_splits, neighbors_distance
119 
120  if (device.IsCUDA()) {
121 #ifdef BUILD_CUDA_MODULE
122  const Dtype index_dtype = GetIndexDtype();
123  DISPATCH_FLOAT_INT_DTYPE_TO_TEMPLATE(dtype, index_dtype, [&]() {
124  KnnSearchCUDA<scalar_t, int_t>(KNN_PARAMETERS);
125  });
126 #else
128  "-DBUILD_CUDA_MODULE=OFF. Please compile Open3d with "
129  "-DBUILD_CUDA_MODULE=ON.");
130 #endif
131  } else {
133  "-DBUILD_CUDA_MODULE=OFF. Please compile Open3d with "
134  "-DBUILD_CUDA_MODULE=ON.");
135  }
136  return std::make_pair(neighbors_index, neighbors_distance);
137 }
138 
139 } // namespace nns
140 } // namespace core
141 } // namespace cloudViewer
#define DISPATCH_FLOAT_INT_DTYPE_TO_TEMPLATE(FDTYPE, IDTYPE,...)
Definition: Dispatch.h:91
#define KNN_PARAMETERS
#define AssertTensorDevice(tensor,...)
Definition: TensorCheck.h:45
#define AssertTensorDtype(tensor,...)
Definition: TensorCheck.h:21
#define AssertTensorDtypes(tensor,...)
Definition: TensorCheck.h:33
#define AssertTensorShape(tensor,...)
Definition: TensorCheck.h:61
bool IsCUDA() const
Returns true iff device type is CUDA.
Definition: Device.h:49
bool IsCUDA() const
Definition: Device.h:99
Tensor Contiguous() const
Definition: Tensor.cpp:772
int64_t NumDims() const
Definition: Tensor.h:1172
static Tensor Empty(const SizeVector &shape, Dtype dtype, const Device &device=Device("CPU:0"))
Create a tensor with uninitialized values.
Definition: Tensor.cpp:400
SizeVector GetShape() const
Definition: Tensor.h:1127
bool SetTensorData(const Tensor &dataset_points, const Dtype &index_dtype=core::Int64) override
Definition: KnnIndex.cpp:32
std::pair< Tensor, Tensor > SearchKnn(const Tensor &query_points, int knn) const override
Definition: KnnIndex.cpp:80
#define LogError(...)
Definition: Logging.h:60
const Dtype Int64
Definition: Dtype.cpp:47
CLOUDVIEWER_HOST_DEVICE Pair< First, Second > make_pair(const First &_first, const Second &_second)
Definition: SlabTraits.h:49
const Dtype Float64
Definition: Dtype.cpp:43
const Dtype Int32
Definition: Dtype.cpp:46
const Dtype Float32
Definition: Dtype.cpp:42
constexpr nullopt_t nullopt
Definition: Optional.h:136
Generic file read and write utility for python interface.