18 const std::vector<Tensor>& index_tensors) {
19 bool index_dim_started =
false;
20 bool index_dim_ended =
false;
21 for (
const Tensor& index_tensor : index_tensors) {
22 if (index_tensor.NumDims() == 0) {
24 if (index_dim_started) {
25 index_dim_ended =
true;
29 if (index_dim_ended) {
32 if (!index_dim_started) {
33 index_dim_started =
true;
40 std::pair<Tensor, std::vector<Tensor>>
42 const Tensor& tensor,
const std::vector<Tensor>& index_tensors) {
43 int64_t ndims = tensor.
NumDims();
44 std::vector<int64_t> permutation;
45 std::vector<Tensor> permuted_index_tensors;
46 for (int64_t i = 0; i < ndims; ++i) {
47 if (index_tensors[i].NumDims() != 0) {
48 permutation.push_back(i);
49 permuted_index_tensors.emplace_back(index_tensors[i]);
52 for (int64_t i = 0; i < ndims; ++i) {
53 if (index_tensors[i].NumDims() == 0) {
54 permutation.push_back(i);
55 permuted_index_tensors.emplace_back(index_tensors[i]);
59 std::move(permuted_index_tensors));
64 const std::vector<Tensor>& index_tensors) {
66 for (
const Tensor& index_tensor : index_tensors) {
67 if (index_tensor.NumDims() != 0) {
69 replacement_shape, index_tensor.GetShape());
73 std::vector<Tensor> expanded_tensors;
74 for (
const Tensor& index_tensor : index_tensors) {
75 if (index_tensor.NumDims() == 0) {
76 expanded_tensors.push_back(index_tensor);
78 expanded_tensors.push_back(index_tensor.Expand(replacement_shape));
91 int64_t end = dims_before + dims_indexed;
95 replacement_shape.
end());
96 strides.
insert(strides.
begin() + dims_before, replacement_shape.
size(), 0);
101 const Tensor& index_tensor, int64_t dims_before, int64_t dims_after) {
105 new_shape.
begin() + dims_before);
114 "Number of index_tensors {} exceeds tensor dimension "
125 "Index tensor must have Int64 dtype, but {} was used.",
126 index_tensor.GetDtype().ToString());
137 Tensor empty_index_tensor =
140 for (int64_t i = 0; i < num_omitted_dims; ++i) {
147 if (index_tensor.NumDims() == 0) {
148 index_tensor.Fill(0);
181 int64_t dims_before = 0;
182 int64_t dims_after = 0;
183 int64_t dims_indexed = 0;
184 bool replacement_shape_inserted =
false;
187 if (dims_indexed == 0) {
194 if (!replacement_shape_inserted) {
196 replacement_shape.
begin(),
197 replacement_shape.
end());
198 replacement_shape_inserted =
true;
211 auto contains_zero = [](
const SizeVector& vals) ->
bool {
212 return std::any_of(vals.begin(), vals.end(),
213 [](int64_t val) { return val == 0; });
215 if (contains_zero(
indexed_shape_) && !contains_zero(replacement_shape)) {
225 dims_before, dims_after);
231 const std::vector<Tensor>& index_tensors) {
232 std::vector<Tensor> res_index_tensors;
233 for (
const Tensor& index_tensor : index_tensors) {
235 std::vector<Tensor> non_zero_indices = index_tensor.NonZeroNumpy();
236 res_index_tensors.insert(res_index_tensors.end(),
237 non_zero_indices.begin(),
238 non_zero_indices.end());
240 res_index_tensors.push_back(index_tensor);
243 return res_index_tensors;
std::vector< Tensor > index_tensors_
The processed index tensors.
SizeVector indexed_strides_
void RunPreprocess()
Preprocess tensor and index tensors.
SizeVector indexed_shape_
static std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim(const std::vector< Tensor > &index_tensors)
static std::vector< Tensor > ExpandBoolTensors(const std::vector< Tensor > &index_tensors)
Expand boolean tensor to integer index.
static bool IsIndexSplittedBySlice(const std::vector< Tensor > &index_tensors)
static Tensor RestrideIndexTensor(const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)
static std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
static Tensor RestrideTensor(const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)
SizeVector output_shape_
Output shape.
iterator erase(const_iterator CI)
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
Tensor AsStrided(const SizeVector &new_shape, const SizeVector &new_strides) const
Create a Tensor view of specified shape and strides. The underlying buffer and data_ptr offsets remai...
Tensor Permute(const SizeVector &dims) const
Permute (dimension shuffle) the Tensor, returns a view.
SizeVector GetStrides() const
int64_t GetStride(int64_t dim) const
Device GetDevice() const override
Tensor Reshape(const SizeVector &dst_shape) const
SizeVector GetShape() const
SizeVector BroadcastedShape(const SizeVector &l_shape, const SizeVector &r_shape)
Returns the broadcasted shape of two shapes.
CLOUDVIEWER_HOST_DEVICE Pair< First, Second > make_pair(const First &_first, const Second &_second)
Generic file read and write utility for python interface.