ACloudViewer
3.9.4
A Modern Library for 3D Data Processing
This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.
#include <AdvancedIndexing.h>

Public Member Functions

| Return type | Member |
|---|---|
| | AdvancedIndexPreprocessor (const Tensor &tensor, const std::vector< Tensor > &index_tensors) |
| Tensor | GetTensor () const |
| std::vector< Tensor > | GetIndexTensors () const |
| SizeVector | GetOutputShape () const |
| SizeVector | GetIndexedShape () const |
| SizeVector | GetIndexedStrides () const |
Static Public Member Functions

| Return type | Member |
|---|---|
| static bool | IsIndexSplittedBySlice (const std::vector< Tensor > &index_tensors) |
| static std::pair< Tensor, std::vector< Tensor > > | ShuffleIndexedDimsToFront (const Tensor &tensor, const std::vector< Tensor > &index_tensors) |
| static std::pair< std::vector< Tensor >, SizeVector > | ExpandToCommonShapeExceptZeroDim (const std::vector< Tensor > &index_tensors) |
| static Tensor | RestrideTensor (const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape) |
| static Tensor | RestrideIndexTensor (const Tensor &index_tensor, int64_t dims_before, int64_t dims_after) |
Protected Member Functions

| Return type | Member | Description |
|---|---|---|
| void | RunPreprocess () | Preprocess tensor and index tensors. |
Static Protected Member Functions

| Return type | Member | Description |
|---|---|---|
| static std::vector< Tensor > | ExpandBoolTensors (const std::vector< Tensor > &index_tensors) | Expand boolean tensor to integer index. |
Protected Attributes

| Type | Member | Description |
|---|---|---|
| Tensor | tensor_ | |
| std::vector< Tensor > | index_tensors_ | The processed index tensors. |
| SizeVector | output_shape_ | Output shape. |
| SizeVector | indexed_shape_ | |
| SizeVector | indexed_strides_ | |
This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.
Definition at line 21 of file AdvancedIndexing.h.
AdvancedIndexPreprocessor (const Tensor &tensor, const std::vector< Tensor > &index_tensors)  [inline]
Definition at line 23 of file AdvancedIndexing.h.
References RunPreprocess().
std::vector< Tensor > ExpandBoolTensors (const std::vector< Tensor > &index_tensors)  [static, protected]
Expand boolean tensor to integer index.
Definition at line 230 of file AdvancedIndexing.cpp.
References cloudViewer::core::Bool.
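The expansion of a boolean index into integer indices can be sketched as follows. This assumes the NumPy/PyTorch convention that a k-dimensional boolean mask becomes k integer index lists holding the coordinates of its true elements; the source only states that boolean tensors are expanded to integer indices. `BoolMaskToIndices` is a hypothetical helper working on a flat, row-major `std::vector<bool>` rather than a `Tensor`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the library): converts a row-major boolean
// mask of the given shape into one integer index list per mask dimension,
// holding the coordinates of every true element ("nonzero" convention).
std::vector<std::vector<int64_t>> BoolMaskToIndices(
        const std::vector<bool>& mask, const std::vector<int64_t>& shape) {
    int64_t numel = 1;
    for (int64_t s : shape) numel *= s;
    assert(static_cast<int64_t>(mask.size()) == numel);

    std::vector<std::vector<int64_t>> indices(shape.size());
    for (int64_t flat = 0; flat < numel; ++flat) {
        if (!mask[flat]) continue;
        // Decompose the flat offset into per-dimension coordinates,
        // starting from the last (fastest-varying) dimension.
        int64_t rem = flat;
        for (int64_t d = static_cast<int64_t>(shape.size()) - 1; d >= 0; --d) {
            indices[d].push_back(rem % shape[d]);
            rem /= shape[d];
        }
    }
    return indices;
}
```

For a 2x3 mask `{0, 1, 0, 1, 0, 1}`, the sketch yields row indices `{0, 1, 1}` and column indices `{1, 0, 2}`, which can then be treated like ordinary integer index tensors.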
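std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim (const std::vector< Tensor > &index_tensors)  [static]
Expands all index tensors to their common broadcasted shape; 0-dim tensors are ignored. Throws an exception if no common broadcasted shape exists.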
Definition at line 63 of file AdvancedIndexing.cpp.
References cloudViewer::core::shape_util::BroadcastedShape(), and cloudViewer::core::make_pair().
Referenced by RunPreprocess().
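The shape computation behind this step can be sketched as below. The sketch assumes NumPy-style broadcasting (align trailing dimensions; each pair of sizes must be equal or contain a 1), which the reference to shape_util::BroadcastedShape() suggests but the page does not spell out. Shapes are plain `std::vector<int64_t>`, an empty shape stands for a 0-dim index and is skipped, and `BroadcastTwo` and `CommonShapeExceptZeroDim` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;

// Broadcasts two shapes under NumPy rules (assumption): trailing dimensions
// are aligned, and each pair of sizes must be equal or contain a 1.
Shape BroadcastTwo(const Shape& a, const Shape& b) {
    const size_t ndim = std::max(a.size(), b.size());
    Shape out(ndim);
    for (size_t i = 0; i < ndim; ++i) {
        // Walk from the last dimension; missing leading dims count as 1.
        int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        assert(da >= 0 && db >= 0);
        if (da != db && da != 1 && db != 1) {
            throw std::runtime_error("shapes are not broadcastable");
        }
        out[ndim - 1 - i] = (da == 1) ? db : da;
    }
    return out;
}

// Common shape of all index shapes, skipping 0-dim (empty) shapes.
// Throws std::runtime_error if no common shape exists.
Shape CommonShapeExceptZeroDim(const std::vector<Shape>& index_shapes) {
    Shape common;  // stays 0-dim until the first non-0-dim index is seen
    for (const Shape& s : index_shapes) {
        if (s.empty()) continue;
        common = BroadcastTwo(common, s);
    }
    return common;
}
```

With index shapes `{3, 1}`, `{}` and `{4}`, the common shape is `{3, 4}`; shapes `{2}` and `{3}` have no common shape and throw.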
SizeVector GetIndexedShape () const  [inline]
Definition at line 37 of file AdvancedIndexing.h.
References indexed_shape_.
Referenced by cloudViewer::core::Tensor::IndexGet(), and cloudViewer::core::Tensor::IndexSet().
SizeVector GetIndexedStrides () const  [inline]
Definition at line 39 of file AdvancedIndexing.h.
References indexed_strides_.
Referenced by cloudViewer::core::Tensor::IndexGet(), and cloudViewer::core::Tensor::IndexSet().
std::vector< Tensor > GetIndexTensors () const  [inline]
Definition at line 31 of file AdvancedIndexing.h.
References index_tensors_.
Referenced by cloudViewer::core::Tensor::IndexGet(), and cloudViewer::core::Tensor::IndexSet().
SizeVector GetOutputShape () const  [inline]
Definition at line 35 of file AdvancedIndexing.h.
References output_shape_.
Referenced by cloudViewer::core::Tensor::IndexGet().
Tensor GetTensor () const  [inline]
Definition at line 29 of file AdvancedIndexing.h.
References tensor_.
Referenced by cloudViewer::core::Tensor::IndexGet(), and cloudViewer::core::Tensor::IndexSet().
bool IsIndexSplittedBySlice (const std::vector< Tensor > &index_tensors)  [static]
Returns true if the indexed dimensions are split by a (full) slice. For example, A[[1, 2], :, [1, 2]] returns true, whereas A[[1, 2], [1, 2], :] returns false.
Definition at line 17 of file AdvancedIndexing.cpp.
Referenced by RunPreprocess().
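The check can be sketched as a single scan over the dimensions. This is a minimal sketch under one assumption: each dimension is summarized by a flag saying whether it carries an index tensor (`true`) or is a full slice `:` (`false`), instead of inspecting real `Tensor` objects. `IsSplitBySlice` is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch: is_indexed[d] is true when dimension d carries an index
// tensor and false when it is a full slice (":"). Returns true when a slice
// dimension separates two groups of indexed dimensions.
bool IsSplitBySlice(const std::vector<bool>& is_indexed) {
    bool index_started = false;  // seen at least one indexed dimension
    bool index_ended = false;    // seen a slice after an indexed dimension
    for (bool indexed : is_indexed) {
        if (indexed) {
            if (index_ended) return true;  // indexing resumes after a slice
            index_started = true;
        } else if (index_started) {
            index_ended = true;
        }
    }
    return false;
}
```

The two examples from the description map to `{true, false, true}` (split, returns true) and `{true, true, false}` (not split, returns false). Leading slices, as in `{false, true, true}`, do not count as a split.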
Tensor RestrideIndexTensor (const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)  [static]
Definition at line 100 of file AdvancedIndexing.cpp.
References cloudViewer::core::SmallVectorTemplateCommon< T, typename >::begin(), copy, cloudViewer::core::SmallVectorTemplateCommon< T, typename >::end(), cloudViewer::core::Tensor::GetShape(), cloudViewer::core::Tensor::NumDims(), and cloudViewer::core::Tensor::Reshape().
Referenced by RunPreprocess().
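Tensor RestrideTensor (const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)  [static]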
Definition at line 85 of file AdvancedIndexing.cpp.
References cloudViewer::core::Tensor::AsStrided(), cloudViewer::core::SmallVectorTemplateCommon< T, typename >::begin(), cloudViewer::core::SmallVectorTemplateCommon< T, typename >::end(), cloudViewer::core::SmallVectorImpl< T >::erase(), cloudViewer::core::Tensor::GetShape(), cloudViewer::core::Tensor::GetStrides(), cloudViewer::core::SmallVectorImpl< T >::insert(), and cloudViewer::core::SmallVectorBase< Size_T >::size().
Referenced by RunPreprocess().
void RunPreprocess ()  [protected]
Preprocess tensor and index tensors.
Definition at line 110 of file AdvancedIndexing.cpp.
References cloudViewer::core::SmallVectorTemplateCommon< T, typename >::begin(), cloudViewer::core::SmallVectorTemplateCommon< T, typename >::end(), ExpandToCommonShapeExceptZeroDim(), cloudViewer::core::Tensor::GetDevice(), cloudViewer::core::Tensor::GetShape(), cloudViewer::core::Tensor::GetStride(), index_tensors_, indexed_shape_, indexed_strides_, cloudViewer::core::SmallVectorImpl< T >::insert(), cloudViewer::core::Int64, IsIndexSplittedBySlice(), LogError, cloudViewer::core::Tensor::NumDims(), output_shape_, cloudViewer::core::SmallVectorTemplateBase< T, bool >::push_back(), RestrideIndexTensor(), RestrideTensor(), ShuffleIndexedDimsToFront(), and tensor_.
Referenced by AdvancedIndexPreprocessor().
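std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront (const Tensor &tensor, const std::vector< Tensor > &index_tensors)  [static]
Shuffles the indexed dimensions in front of the slice dimensions, for both the tensor and the index tensors.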
Definition at line 41 of file AdvancedIndexing.cpp.
References cloudViewer::core::make_pair(), cloudViewer::core::Tensor::NumDims(), and cloudViewer::core::Tensor::Permute().
Referenced by RunPreprocess().
std::vector< Tensor > index_tensors_  [protected]
The processed index tensors.
Definition at line 94 of file AdvancedIndexing.h.
Referenced by GetIndexTensors(), and RunPreprocess().
SizeVector indexed_shape_  [protected]
The shape of the indexed dimensions. See the docstring of RestrideTensor for details.
Definition at line 101 of file AdvancedIndexing.h.
Referenced by GetIndexedShape(), and RunPreprocess().
SizeVector indexed_strides_  [protected]
The strides for indexed dimensions, in element numbers (not byte size). See the docstring of RestrideTensor for details.
Definition at line 105 of file AdvancedIndexing.h.
Referenced by GetIndexedStrides(), and RunPreprocess().
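Because these strides count elements rather than bytes, an offset computed from them has to be scaled by the element size before being used as a byte offset. The sketch below illustrates that arithmetic. It assumes the offset for one set of index values is the usual dot product of index values and strides. `ElementOffset` and `ByteOffset` are hypothetical helpers, not library functions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: element offset addressed by one set of index values,
// given strides measured in elements (as indexed_strides_ is documented).
int64_t ElementOffset(const std::vector<int64_t>& index_values,
                      const std::vector<int64_t>& strides_in_elements) {
    assert(index_values.size() == strides_in_elements.size());
    int64_t offset = 0;
    for (size_t k = 0; k < index_values.size(); ++k)
        offset += index_values[k] * strides_in_elements[k];
    return offset;
}

// Converting to bytes requires the element size of the tensor's dtype.
int64_t ByteOffset(int64_t element_offset, std::size_t element_size) {
    return element_offset * static_cast<int64_t>(element_size);
}
```

With strides `{6, 1}` and index values `{2, 3}`, the element offset is 15. For a 4-byte `float` tensor, this corresponds to a byte offset of 60.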
SizeVector output_shape_  [protected]
Output shape.
Definition at line 97 of file AdvancedIndexing.h.
Referenced by GetOutputShape(), and RunPreprocess().
Tensor tensor_  [protected]
The processed tensor being indexed. It still uses the same underlying memory, but it may have been reshaped and restrided.
Definition at line 91 of file AdvancedIndexing.h.
Referenced by GetTensor(), and RunPreprocess().