ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
cloudViewer::core::AdvancedIndexPreprocessor Class Reference

This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.

#include <AdvancedIndexing.h>

Public Member Functions

 AdvancedIndexPreprocessor (const Tensor &tensor, const std::vector< Tensor > &index_tensors)
 
Tensor GetTensor () const
 
std::vector< Tensor > GetIndexTensors () const
 
SizeVector GetOutputShape () const
 
SizeVector GetIndexedShape () const
 
SizeVector GetIndexedStrides () const
 

Static Public Member Functions

static bool IsIndexSplittedBySlice (const std::vector< Tensor > &index_tensors)
 
static std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront (const Tensor &tensor, const std::vector< Tensor > &index_tensors)
 
static std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim (const std::vector< Tensor > &index_tensors)
 
static Tensor RestrideTensor (const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)
 
static Tensor RestrideIndexTensor (const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)
 

Protected Member Functions

void RunPreprocess ()
 Preprocess tensor and index tensors.
 

Static Protected Member Functions

static std::vector< Tensor > ExpandBoolTensors (const std::vector< Tensor > &index_tensors)
 Expand boolean tensor to integer index.
 

Protected Attributes

Tensor tensor_
 
std::vector< Tensor > index_tensors_
 The processed index tensors.
 
SizeVector output_shape_
 Output shape.
 
SizeVector indexed_shape_
 
SizeVector indexed_strides_
 

Detailed Description

This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.

Definition at line 21 of file AdvancedIndexing.h.

Constructor & Destructor Documentation

◆ AdvancedIndexPreprocessor()

cloudViewer::core::AdvancedIndexPreprocessor::AdvancedIndexPreprocessor ( const Tensor &  tensor,
const std::vector< Tensor > &  index_tensors 
)
inline

Definition at line 23 of file AdvancedIndexing.h.

References RunPreprocess().

Member Function Documentation

◆ ExpandBoolTensors()

std::vector< Tensor > cloudViewer::core::AdvancedIndexPreprocessor::ExpandBoolTensors ( const std::vector< Tensor > &  index_tensors)
static protected

Expand boolean tensor to integer index.
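
The page does not show how the expansion is done. As a rough illustration only, a boolean mask can be replaced by the integer positions at which it is true. The sketch below does this for a 1-D mask stored in a std::vector<bool>; the helper name MaskToIndices is hypothetical and is not part of cloudViewer, and the real function operates on Tensor objects.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not a cloudViewer API): converts a 1-D boolean mask
// into the list of positions where the mask is true.
std::vector<int64_t> MaskToIndices(const std::vector<bool>& mask) {
    std::vector<int64_t> indices;
    const int64_t n = static_cast<int64_t>(mask.size());
    for (int64_t i = 0; i < n; ++i) {
        if (mask[i]) {
            indices.push_back(i);
        }
    }
    return indices;
}
```

For example, the mask {true, false, true, true} yields the indices {0, 2, 3}.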

Definition at line 230 of file AdvancedIndexing.cpp.

References cloudViewer::core::Bool.

◆ ExpandToCommonShapeExceptZeroDim()

std::pair< std::vector< Tensor >, SizeVector > cloudViewer::core::AdvancedIndexPreprocessor::ExpandToCommonShapeExceptZeroDim ( const std::vector< Tensor > &  index_tensors)
static

Expands all tensors to the common broadcast shape; 0-dim tensors are ignored. Throws an exception if no common broadcast shape exists.
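
A minimal sketch of the shape computation, assuming NumPy-style broadcasting rules and using plain std::vector<int64_t> as a stand-in for SizeVector. The names Shape, BroadcastTwo and CommonShapeExceptZeroDim are hypothetical; the library itself relies on shape_util::BroadcastedShape(), whose exact behavior is not shown on this page.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for SizeVector (assumption)

// Broadcasts two shapes, aligning dimensions from the right.
Shape BroadcastTwo(const Shape& a, const Shape& b) {
    const std::size_t n = std::max(a.size(), b.size());
    Shape out(n, 1);
    for (std::size_t i = 0; i < n; ++i) {
        // Missing leading dimensions are treated as size 1.
        const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        if (da != db && da != 1 && db != 1) {
            throw std::invalid_argument("shapes are not broadcast-compatible");
        }
        out[n - 1 - i] = (da == 1) ? db : da;
    }
    return out;
}

// Returns the common shape of all non-0-dim shapes; empty (0-dim) shapes
// are skipped rather than broadcast.
Shape CommonShapeExceptZeroDim(const std::vector<Shape>& shapes) {
    Shape common;
    for (const Shape& s : shapes) {
        if (s.empty()) {
            continue;  // 0-dim entry: ignored
        }
        common = common.empty() ? s : BroadcastTwo(common, s);
    }
    return common;
}
```

With the shapes {2, 1}, {} and {3}, the 0-dim entry is skipped and the common shape is {2, 3}. Shapes {2} and {3} have no common broadcast shape, so the sketch throws.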

Definition at line 63 of file AdvancedIndexing.cpp.

References cloudViewer::core::shape_util::BroadcastedShape(), and cloudViewer::core::make_pair().

Referenced by RunPreprocess().

◆ GetIndexedShape()

SizeVector cloudViewer::core::AdvancedIndexPreprocessor::GetIndexedShape ( ) const
inline

◆ GetIndexedStrides()

SizeVector cloudViewer::core::AdvancedIndexPreprocessor::GetIndexedStrides ( ) const
inline

◆ GetIndexTensors()

std::vector<Tensor> cloudViewer::core::AdvancedIndexPreprocessor::GetIndexTensors ( ) const
inline

◆ GetOutputShape()

SizeVector cloudViewer::core::AdvancedIndexPreprocessor::GetOutputShape ( ) const
inline

Definition at line 35 of file AdvancedIndexing.h.

References output_shape_.

Referenced by cloudViewer::core::Tensor::IndexGet().

◆ GetTensor()

Tensor cloudViewer::core::AdvancedIndexPreprocessor::GetTensor ( ) const
inline

◆ IsIndexSplittedBySlice()

bool cloudViewer::core::AdvancedIndexPreprocessor::IsIndexSplittedBySlice ( const std::vector< Tensor > &  index_tensors)
static

Returns true if the indexed dimensions are split by a (full) slice. E.g. A[[1, 2], :, [1, 2]] returns true, while A[[1, 2], [1, 2], :] returns false.
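
The check can be sketched on a simplified representation in which each dimension is marked as either indexed or covered by a full slice. The function name IsSplitBySlice and the std::vector<bool> encoding are assumptions made for illustration; the real function inspects the index tensors directly.

```cpp
#include <vector>

// Hypothetical stand-in: is_indexed[i] is true when dimension i is indexed
// by an index tensor and false when it is covered by a full slice (":").
bool IsSplitBySlice(const std::vector<bool>& is_indexed) {
    bool seen_index = false;         // an indexed dimension has been seen
    bool slice_after_index = false;  // a slice followed an indexed dimension
    for (bool indexed : is_indexed) {
        if (indexed) {
            if (slice_after_index) {
                return true;  // pattern: index, slice, index
            }
            seen_index = true;
        } else if (seen_index) {
            slice_after_index = true;
        }
    }
    return false;
}
```

Under this encoding, A[[1, 2], :, [1, 2]] corresponds to {true, false, true} and gives true, while A[[1, 2], [1, 2], :] corresponds to {true, true, false} and gives false.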

Definition at line 17 of file AdvancedIndexing.cpp.

Referenced by RunPreprocess().

◆ RestrideIndexTensor()

Tensor cloudViewer::core::AdvancedIndexPreprocessor::RestrideIndexTensor ( const Tensor &  index_tensor,
int64_t  dims_before,
int64_t  dims_after 
)
static

◆ RestrideTensor()

◆ RunPreprocess()

◆ ShuffleIndexedDimsToFront()

std::pair< Tensor, std::vector< Tensor > > cloudViewer::core::AdvancedIndexPreprocessor::ShuffleIndexedDimsToFront ( const Tensor &  tensor,
const std::vector< Tensor > &  index_tensors 
)
static

Moves the indexed dimensions ahead of the slice dimensions, for both the tensor and the index tensors.
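
One way to picture the reordering is as a permutation of dimension numbers that could then be passed to Tensor::Permute(). The sketch below only computes such a permutation; the name IndexedDimsFirstPermutation and the std::vector<bool> encoding of indexed versus sliced dimensions are hypothetical, and keeping the relative order within each group is an assumption.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: is_indexed[i] is true when dimension i is indexed.
// Returns a permutation listing indexed dimensions first, then the sliced
// ones, keeping the original relative order within each group.
std::vector<int64_t> IndexedDimsFirstPermutation(
        const std::vector<bool>& is_indexed) {
    const int64_t n = static_cast<int64_t>(is_indexed.size());
    std::vector<int64_t> perm;
    perm.reserve(is_indexed.size());
    for (int64_t i = 0; i < n; ++i) {
        if (is_indexed[i]) perm.push_back(i);
    }
    for (int64_t i = 0; i < n; ++i) {
        if (!is_indexed[i]) perm.push_back(i);
    }
    return perm;
}
```

For A[[1, 2], :, [1, 2]], encoded as {true, false, true}, the permutation is {0, 2, 1}: the sliced dimension 1 moves to the end.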

Definition at line 41 of file AdvancedIndexing.cpp.

References cloudViewer::core::make_pair(), cloudViewer::core::Tensor::NumDims(), and cloudViewer::core::Tensor::Permute().

Referenced by RunPreprocess().

Member Data Documentation

◆ index_tensors_

std::vector<Tensor> cloudViewer::core::AdvancedIndexPreprocessor::index_tensors_
protected

The processed index tensors.

Definition at line 94 of file AdvancedIndexing.h.

Referenced by GetIndexTensors(), and RunPreprocess().

◆ indexed_shape_

SizeVector cloudViewer::core::AdvancedIndexPreprocessor::indexed_shape_
protected

The shape of the indexed dimensions. See the docstring of RestrideTensor for details.

Definition at line 101 of file AdvancedIndexing.h.

Referenced by GetIndexedShape(), and RunPreprocess().

◆ indexed_strides_

SizeVector cloudViewer::core::AdvancedIndexPreprocessor::indexed_strides_
protected

The strides for indexed dimensions, in element numbers (not byte size). See the docstring of RestrideTensor for details.
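
To make the distinction between element strides and byte strides concrete, the sketch below computes row-major strides in elements for a contiguous shape and converts one of them to bytes. Both helper names are hypothetical, and the contiguous layout is an assumption; the strides stored in this member depend on how the tensor was restrided.

```cpp
#include <cstdint>
#include <vector>

// Row-major (contiguous) strides, measured in elements.
std::vector<int64_t> ContiguousElementStrides(
        const std::vector<int64_t>& shape) {
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i) {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    return strides;
}

// A byte stride is the element stride multiplied by the element size.
int64_t ToByteStride(int64_t element_stride, int64_t element_size_bytes) {
    return element_stride * element_size_bytes;
}
```

For shape {2, 3, 4} the element strides are {12, 4, 1}. With 4-byte elements, the first stride corresponds to 48 bytes.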

Definition at line 105 of file AdvancedIndexing.h.

Referenced by GetIndexedStrides(), and RunPreprocess().

◆ output_shape_

SizeVector cloudViewer::core::AdvancedIndexPreprocessor::output_shape_
protected

Output shape.

Definition at line 97 of file AdvancedIndexing.h.

Referenced by GetOutputShape(), and RunPreprocess().

◆ tensor_

Tensor cloudViewer::core::AdvancedIndexPreprocessor::tensor_
protected

The processed tensor being indexed. The tensor still uses the same underlying memory, but it may have been reshaped and restrided.

Definition at line 91 of file AdvancedIndexing.h.

Referenced by GetTensor(), and RunPreprocess().


The documentation for this class was generated from the following files: