ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
IndexGetSet.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
9 
10 #include <Logging.h>
11 
12 #include "cloudViewer/core/Dtype.h"
17 
18 namespace cloudViewer {
19 namespace core {
20 namespace kernel {
21 
22 void IndexGet(const Tensor& src,
23  Tensor& dst,
24  const std::vector<Tensor>& index_tensors,
25  const SizeVector& indexed_shape,
26  const SizeVector& indexed_strides) {
27  // index_tensors has been preprocessed to be on the same device as src,
28  // however, dst may be in a different device.
29  if (dst.GetDevice() != src.GetDevice()) {
30  Tensor dst_same_device(dst.GetShape(), dst.GetDtype(), src.GetDevice());
31  IndexGet(src, dst_same_device, index_tensors, indexed_shape,
32  indexed_strides);
33  dst.CopyFrom(dst_same_device);
34  return;
35  }
36 
37  if (src.IsCPU()) {
38  IndexGetCPU(src, dst, index_tensors, indexed_shape, indexed_strides);
39  } else if (src.IsSYCL()) {
40 #ifdef BUILD_SYCL_MODULE
41  IndexGetSYCL(src, dst, index_tensors, indexed_shape, indexed_strides);
42 #endif
43  } else if (src.IsCUDA()) {
44 #ifdef BUILD_CUDA_MODULE
45  IndexGetCUDA(src, dst, index_tensors, indexed_shape, indexed_strides);
46 #endif
47  } else {
48  utility::LogError("IndexGet: Unimplemented device");
49  }
50 }
51 
52 void IndexSet(const Tensor& src,
53  Tensor& dst,
54  const std::vector<Tensor>& index_tensors,
55  const SizeVector& indexed_shape,
56  const SizeVector& indexed_strides) {
57  // index_tensors has been preprocessed to be on the same device as dst,
58  // however, src may be on a different device.
59  Tensor src_same_device = src.To(dst.GetDevice());
60 
61  if (dst.IsCPU()) {
62  IndexSetCPU(src_same_device, dst, index_tensors, indexed_shape,
63  indexed_strides);
64  } else if (dst.IsSYCL()) {
65 #ifdef BUILD_SYCL_MODULE
66  IndexSetSYCL(src_same_device, dst, index_tensors, indexed_shape,
67  indexed_strides);
68 #endif
69  } else if (dst.IsCUDA()) {
70 #ifdef BUILD_CUDA_MODULE
71  IndexSetCUDA(src_same_device, dst, index_tensors, indexed_shape,
72  indexed_strides);
73 #endif
74  } else {
75  utility::LogError("IndexSet: Unimplemented device");
76  }
77 }
78 
79 } // namespace kernel
80 } // namespace core
81 } // namespace cloudViewer
bool IsCUDA() const
Definition: Device.h:99
bool IsCPU() const
Definition: Device.h:95
void CopyFrom(const Tensor &other)
Copy Tensor values to current tensor from the source tensor.
Definition: Tensor.cpp:770
Dtype GetDtype() const
Definition: Tensor.h:1164
Device GetDevice() const override
Definition: Tensor.cpp:1435
SizeVector GetShape() const
Definition: Tensor.h:1127
Tensor To(Dtype dtype, bool copy=false) const
Definition: Tensor.cpp:739
#define LogError(...)
Definition: Logging.h:60
void IndexGetCPU(const Tensor &src, Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides)
void IndexSet(const Tensor &src, Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides)
Definition: IndexGetSet.cpp:52
void IndexSetSYCL(const Tensor &src, Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides)
void IndexGetSYCL(const Tensor &src, Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides)
void IndexSetCPU(const Tensor &src, Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides)
void IndexGet(const Tensor &src, Tensor &dst, const std::vector< Tensor > &index_tensors, const SizeVector &indexed_shape, const SizeVector &indexed_strides)
Definition: IndexGetSet.cpp:22
Generic file read and write utility for python interface.