24 const std::vector<Tensor>& index_tensors)
45 const std::vector<Tensor>& index_tensors);
50 const Tensor& tensor,
const std::vector<Tensor>& index_tensors);
54 static std::pair<std::vector<Tensor>,
SizeVector>
87 const std::vector<Tensor>& index_tensors);
123 const std::vector<Tensor>& index_tensors,
128 if (indexed_shape.
size() != indexed_strides.
size()) {
130 "Internal error: indexed_shape's ndim {} does not equal to "
131 "indexed_strides' ndim {}",
132 indexed_shape.
size(), indexed_strides.
size());
137 std::vector<Tensor> inputs;
138 inputs.push_back(src);
139 for (
const Tensor& index_tensor : index_tensors) {
140 if (index_tensor.NumDims() != 0) {
141 inputs.push_back(index_tensor);
149 "Internal error: indexed_shape's ndim {} does not equal to "
150 "indexed_strides' ndim {}",
161 "src's dtype {} is not the same as dst's dtype {}.",
168 int64_t workload_idx)
const {
176 int64_t workload_idx)
const {
187 int64_t index = *(
reinterpret_cast<int64_t*
>(
191 "Index out of bounds.");
#define CLOUDVIEWER_HOST_DEVICE
#define CLOUDVIEWER_ASSERT(...)
This class is based on PyTorch's aten/src/ATen/native/Indexing.cpp.
std::vector< Tensor > index_tensors_
The processed index tensors.
SizeVector indexed_strides_
std::vector< Tensor > GetIndexTensors() const
void RunPreprocess()
Preprocess tensor and index tensors.
SizeVector indexed_shape_
static std::pair< std::vector< Tensor >, SizeVector > ExpandToCommonShapeExceptZeroDim(const std::vector< Tensor > &index_tensors)
static std::vector< Tensor > ExpandBoolTensors(const std::vector< Tensor > &index_tensors)
Expand boolean tensor to integer index.
SizeVector GetIndexedStrides() const
static bool IsIndexSplittedBySlice(const std::vector< Tensor > &index_tensors)
static Tensor RestrideIndexTensor(const Tensor &index_tensor, int64_t dims_before, int64_t dims_after)
static std::pair< Tensor, std::vector< Tensor > > ShuffleIndexedDimsToFront(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
AdvancedIndexPreprocessor(const Tensor &tensor, const std::vector< Tensor > &index_tensors)
SizeVector GetIndexedShape() const
static Tensor RestrideTensor(const Tensor &tensor, int64_t dims_before, int64_t dims_indexed, SizeVector replacement_shape)
SizeVector output_shape_
Output shape.
SizeVector GetOutputShape() const
AdvancedIndexerMode mode_
CLOUDVIEWER_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
CLOUDVIEWER_HOST_DEVICE int64_t GetIndexedOffset(int64_t workload_idx) const
int64_t indexed_strides_[MAX_DIMS]
int64_t NumWorkloads() const
int64_t indexed_shape_[MAX_DIMS]
CLOUDVIEWER_HOST_DEVICE char * GetInputPtr(int64_t workload_idx) const
int64_t element_byte_size_
AdvancedIndexer(const Tensor &src, const Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides, AdvancedIndexerMode mode)
std::string ToString() const
CLOUDVIEWER_HOST_DEVICE char * GetInputPtr(int64_t input_idx, int64_t workload_idx) const
CLOUDVIEWER_HOST_DEVICE char * GetOutputPtr(int64_t workload_idx) const
int64_t NumWorkloads() const
static constexpr int64_t MAX_DIMS
Generic file read and write utility for python interface.