// KnnSearch: validate input shapes (both point tensors must be 2-D {N, d}),
// then run the k-nearest-neighbor query.
if (query_points.NumDims() != 2) {
    // (message reconstructed by symmetry with the dataset_points check)
    utility::LogError("query_points must be of shape {n_query_points, d}.");
}
if (dataset_points.NumDims() != 2) {
    utility::LogError(
            "dataset_points must be of shape {n_dataset_points, d}.");
}
// ... (construction of the nns index over dataset_points is elided in this listing) ...
std::tie(indices, distances) = nns.KnnSearch(query_points, knn);
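// --------------------------------------------------------------------------
// Hedged usage sketch, not part of the original listing: a minimal,
// self-contained example of driving a KNN query like the wrapper above.
// The tensor values and knn=2 are illustrative; the calls (Tensor::Init,
// NearestNeighborSearch, KnnIndex, KnnSearch) follow Open3D's public core
// API, but the exact setup used by this file is elided in the listing.
// --------------------------------------------------------------------------
static void ExampleKnnUsage() {
    // Three dataset points and one query point in 3-D.
    core::Tensor dataset_points = core::Tensor::Init<float>(
            {{0.f, 0.f, 0.f}, {1.f, 0.f, 0.f}, {0.f, 1.f, 0.f}});  // {3, 3}
    core::Tensor query_points =
            core::Tensor::Init<float>({{0.1f, 0.f, 0.f}});  // {1, 3}

    core::nns::NearestNeighborSearch nns(dataset_points);
    nns.KnnIndex();
    core::Tensor indices, distances;
    std::tie(indices, distances) = nns.KnnSearch(query_points, /*knn=*/2);
    // indices   -> {1, 2} indices into dataset_points
    // distances -> {1, 2} distances to those neighbors
}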
// RadiusSearch: validate shapes, then run a fixed-radius search per batch and
// scatter the variable-length neighbor lists into one padded result tensor.
if (query_points.NumDims() != 2) {
    // (message reconstructed by symmetry with the dataset_points check)
    utility::LogError("query_points must be of shape {n_query_points, d}.");
}
if (dataset_points.NumDims() != 2) {
    utility::LogError(
            "dataset_points must be of shape {n_dataset_points, d}.");
}
if (query_batches.NumDims() != 1) {
    // (message reconstructed)
    utility::LogError("query_batches must be of shape {n_batches,}.");
}
if (dataset_batches.NumDims() != 1) {
    // (message reconstructed)
    utility::LogError("dataset_batches must be of shape {n_batches,}.");
}

// The batch tensors hold per-batch point counts; their sums must match the
// total number of query / dataset points.
int64_t num_batches = query_batches.GetShape()[0];
if (query_batches.Sum({0}).Item<int32_t>() != query_points.GetShape()[0]) {
    utility::LogError(
            "query_batches is not consistent with query_points: {} != {}.",
            query_batches.Sum({0}).Item<int32_t>(),
            query_points.GetShape()[0]);
}
if (dataset_batches.Sum({0}).Item<int32_t>() !=
    dataset_points.GetShape()[0]) {
    utility::LogError(
            "dataset_batches is not consistent with dataset_points: {} != "
            "{}.",
            dataset_batches.Sum({0}).Item<int32_t>(),
            dataset_points.GetShape()[0]);
}
int64_t num_query_points = query_points.GetShape()[0];

// Per-batch results, filled in the loop below.
std::vector<core::Tensor> batched_indices(num_batches);
std::vector<core::Tensor> batched_num_neighbors(num_batches);

// Exclusive prefix sums over the batch sizes give each batch's start offset
// inside the flat query / dataset point tensors.
std::vector<int32_t> query_prefix_indices(num_batches + 1, 0);
std::vector<int32_t> dataset_prefix_indices(num_batches + 1, 0);

const int32_t* query_batch_flat =
        static_cast<const int32_t*>(query_batches.GetDataPtr());
const int32_t* dataset_batch_flat =
        static_cast<const int32_t*>(dataset_batches.GetDataPtr());

std::partial_sum(query_batch_flat, query_batch_flat + num_batches,
                 query_prefix_indices.data() + 1);
std::partial_sum(dataset_batch_flat, dataset_batch_flat + num_batches,
                 dataset_prefix_indices.data() + 1);

// Run a fixed-radius search independently for every batch.
for (int64_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
    core::Tensor current_query_points =
            query_points.Slice(0, query_prefix_indices[batch_idx],
                               query_prefix_indices[batch_idx + 1]);
    core::Tensor current_dataset_points =
            dataset_points.Slice(0, dataset_prefix_indices[batch_idx],
                                 dataset_prefix_indices[batch_idx + 1]);

    // (reconstructed) Build the per-batch index and query it; the original
    // lines are elided in this listing.
    core::nns::NearestNeighborSearch nns(current_dataset_points);
    nns.FixedRadiusIndex(radius);
    core::Tensor indices, distances, neighbors_row_splits;
    std::tie(indices, distances, neighbors_row_splits) =
            nns.FixedRadiusSearch(current_query_points, radius);

    batched_indices[batch_idx] = indices;
    int64_t current_num_query_points = current_query_points.GetShape()[0];
    // neighbors_row_splits holds CSR-style offsets; adjacent differences give
    // the neighbor count of each query point.
    core::Tensor num_neighbors =
            neighbors_row_splits.Slice(0, 1, current_num_query_points + 1)
                    .Sub(neighbors_row_splits.Slice(
                            0, 0, current_num_query_points));
    batched_num_neighbors[batch_idx] = num_neighbors;
}

// The widest neighbor list determines the padded output width.
int64_t max_num_neighbors = 0;
for (const auto& num_neighbors : batched_num_neighbors) {
    max_num_neighbors = std::max(num_neighbors.Max({0}).Item<int64_t>(),
                                 max_num_neighbors);
}

// Dense result of shape {num_query_points, max_num_neighbors}, padded with -1.
core::Tensor result = core::Tensor::Full(
        {num_query_points, max_num_neighbors}, -1, core::Int64);

#pragma omp parallel for schedule(static) \
        num_threads(utility::EstimateMaxThreads())
for (int64_t batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
    int32_t result_start_idx = query_prefix_indices[batch_idx];
    int32_t result_end_idx = query_prefix_indices[batch_idx + 1];
    // (reconstructed) Shift per-batch indices to global dataset indices.
    core::Tensor indices = batched_indices[batch_idx].Add(
            dataset_prefix_indices[batch_idx]);
    core::Tensor num_neighbors = batched_num_neighbors[batch_idx];

    int64_t batch_size = result_end_idx - result_start_idx;
    if (num_neighbors.GetShape()[0] != batch_size) {
        utility::LogError(
                "Sanity check failed, batch_id {}: {} != batchsize {}.",
                batch_idx, num_neighbors.GetShape()[0], batch_size);
    }

    // Scatter each query's variable-length neighbor list into its padded row.
    int64_t indices_start_idx = 0;
    for (int64_t i = 0; i < batch_size; i++) {
        int64_t num_neighbor = num_neighbors[i].Item<int64_t>();
        core::Tensor result_slice =
                result.Slice(0, result_start_idx + i, result_start_idx + i + 1)
                        .Slice(1, 0, num_neighbor);
        core::Tensor indices_slice = indices.Slice(
                0, indices_start_idx, indices_start_idx + num_neighbor);
        result_slice.AsRvalue() = indices_slice.View({1, num_neighbor});
        indices_start_idx += num_neighbor;
    }
}
// ... (the function's return of `result` is elided in this listing) ...
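// --------------------------------------------------------------------------
// Hedged usage sketch, not part of the original listing: driving a batched
// radius search with the RadiusSearch wrapper above. Batch sizes, coordinates,
// and the radius are illustrative. query_batches / dataset_batches are Int32
// per-batch point counts whose sums must match the corresponding point
// tensors, as enforced by the checks at the top of RadiusSearch.
// --------------------------------------------------------------------------
static void ExampleRadiusUsage() {
    // Two batches: 2 + 1 query points, 3 + 2 dataset points.
    core::Tensor dataset_points = core::Tensor::Init<float>({{0.f, 0.f, 0.f},
                                                             {1.f, 0.f, 0.f},
                                                             {0.f, 1.f, 0.f},
                                                             {5.f, 5.f, 5.f},
                                                             {5.f, 6.f, 5.f}});
    core::Tensor query_points = core::Tensor::Init<float>(
            {{0.1f, 0.f, 0.f}, {0.f, 0.9f, 0.f}, {5.f, 5.1f, 5.f}});
    core::Tensor query_batches = core::Tensor::Init<int32_t>({2, 1});
    core::Tensor dataset_batches = core::Tensor::Init<int32_t>({3, 2});

    core::Tensor neighbors =
            RadiusSearch(query_points, dataset_points, query_batches,
                         dataset_batches, /*radius=*/0.5);
    // neighbors has shape {3, max_num_neighbors}; each row holds global
    // indices into dataset_points (already offset by the batch's dataset
    // prefix), padded with -1 where a query has fewer neighbors than the max.
}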