30 const Dtype &index_dtype) {
37 const Dtype &index_dtype) {
39 assert(index_dtype ==
Int32 || index_dtype ==
Int64);
41 if (dataset_points.
NumDims() != 2) {
43 "dataset_points must be 2D matrix, with shape "
44 "{n_dataset_points, d}.");
50 holder_ = impl::BuildKdTree<scalar_t, int_t>(
72 const int64_t num_neighbors =
std::min(
73 static_cast<int64_t
>(
GetDatasetSize()),
static_cast<int64_t
>(knn));
74 const int64_t num_query_points = query_points.
GetShape(0);
82 impl::KnnSearchCPU<scalar_t, int_t>(
88 query_contiguous.
GetShape(1), num_neighbors,
L2,
90 true, output_allocator);
93 indices = indices.
View({num_query_points, num_neighbors});
94 distances = distances.
View({num_query_points, num_neighbors});
111 int64_t num_query_points = query_points.
GetShape(0);
117 if (below_zero.Any().Item<
bool>()) {
121 Tensor indices, distances;
127 impl::RadiusSearchCPU<scalar_t, int_t>(
136 false, sort, output_allocator);
141 return std::make_tuple(indices, distances,
146 const Tensor &query_points,
double radius,
bool sort)
const {
147 const int64_t num_query_points = query_points.
GetShape()[0];
149 std::tuple<Tensor, Tensor, Tensor>
result;
151 Tensor radii(std::vector<scalar_t>(num_query_points, (scalar_t)radius),
152 {num_query_points}, dtype);
159 const Tensor &query_points,
double radius,
int max_knn)
const {
175 int64_t num_query_points = query_points.
GetShape(0);
177 Tensor indices, distances, counts;
182 impl::HybridSearchCPU<scalar_t, int_t>(
187 query_contiguous.
GetShape(1),
static_cast<scalar_t
>(radius),
190 true, output_allocator);
193 {num_query_points, max_knn});
195 {num_query_points, max_knn});
198 return std::make_tuple(indices, distances, counts);
#define DISPATCH_FLOAT_DTYPE_TO_TEMPLATE(DTYPE,...)
#define DISPATCH_FLOAT_INT_DTYPE_TO_TEMPLATE(FDTYPE, IDTYPE,...)
#define AssertTensorDevice(tensor,...)
#define AssertTensorDtype(tensor,...)
#define AssertTensorDtypes(tensor,...)
#define AssertTensorShape(tensor,...)
Tensor Contiguous() const
Tensor View(const SizeVector &dst_shape) const
SizeVector GetShape() const
Tensor To(Dtype dtype, bool copy=false) const
Dtype GetIndexDtype() const
size_t GetDatasetSize() const
std::unique_ptr< NanoFlannIndexHolderBase > holder_
NanoFlannIndex()
Default Constructor.
std::tuple< Tensor, Tensor, Tensor > SearchHybrid(const Tensor &query_points, double radius, int max_knn) const override
std::tuple< Tensor, Tensor, Tensor > SearchRadius(const Tensor &query_points, const Tensor &radii, bool sort=true) const override
bool SetTensorData(const Tensor &dataset_points, const Dtype &index_dtype=core::Int64) override
std::pair< Tensor, Tensor > SearchKnn(const Tensor &query_points, int knn) const override
const Tensor & NeighborsDistance() const
const Tensor & NeighborsIndex() const
const Tensor & NeighborsCount() const
CLOUDVIEWER_HOST_DEVICE Pair< First, Second > make_pair(const First &_first, const Second &_second)
constexpr nullopt_t nullopt
Generic file read and write utility for python interface.