ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
Matmul.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
9 
10 #include <unordered_map>
11 
13 
14 namespace cloudViewer {
15 namespace core {
16 
17 void Matmul(const Tensor& A, const Tensor& B, Tensor& output) {
20 
21  const Device device = A.GetDevice();
22  const Dtype dtype_original = A.GetDtype();
23  Dtype dtype;
24 
25  if (dtype_original != core::Float32 && dtype_original != core::Float64) {
26  utility::LogDebug("Converting to Float32 dtype to from {}.",
27  dtype_original.ToString());
28  dtype = core::Float32;
29  } else {
30  dtype = dtype_original;
31  }
32 
33  // Check shapes
34  SizeVector A_shape = A.GetShape();
35  SizeVector B_shape = B.GetShape();
36 
37  if (A_shape.size() != 2) {
38  utility::LogError("Tensor A must be 2D, but got {}D.", A_shape.size());
39  }
40  if (B_shape.size() != 1 && B_shape.size() != 2) {
42  "Tensor B must be 1D (vector) or 2D (matrix), but got {}D.",
43  B_shape.size());
44  }
45  if (A_shape[1] != B_shape[0]) {
46  utility::LogError("Tensor A columns {} mismatch with Tensor B rows {}.",
47  A_shape[1], B_shape[0]);
48  }
49 
50  // Dispatch to backends
51  int64_t m = A_shape[0];
52  int64_t k = A_shape[1];
53  int64_t n = B_shape.size() == 2 ? B_shape[1] : 1;
54 
55  if (m == 0 || k == 0 || n == 0) {
57  "Tensor shapes should not contain dimensions with zero.");
58  }
59 
60  Tensor A_contiguous = A.Contiguous().To(dtype);
61  Tensor B_contiguous = B.Contiguous().To(dtype);
62  void* A_data = A_contiguous.GetDataPtr();
63  void* B_data = B_contiguous.GetDataPtr();
64 
65  output = Tensor::Empty({m, n}, dtype, device);
66  void* C_data = output.GetDataPtr();
67 
68  if (device.IsSYCL()) {
69 #ifdef BUILD_SYCL_MODULE
70  MatmulSYCL(B_data, A_data, C_data, n, k, m, dtype, device);
71 #else
72  utility::LogError("Unimplemented device.");
73 #endif
74  } else if (device.IsCUDA()) {
75 #ifdef BUILD_CUDA_MODULE
76  CUDAScopedDevice scoped_device(device);
77  MatmulCUDA(B_data, A_data, C_data, n, k, m, dtype, device);
78 #else
79  utility::LogError("Unimplemented device.");
80 #endif
81  } else {
82  MatmulCPU(B_data, A_data, C_data, n, k, m, dtype);
83  }
84 
85  output = output.To(dtype_original);
86 };
87 
88 } // namespace core
89 } // namespace cloudViewer
Common CUDA utilities.
#define AssertTensorDevice(tensor,...)
Definition: TensorCheck.h:45
#define AssertTensorDtype(tensor,...)
Definition: TensorCheck.h:21
When CUDA is not enabled, this is a dummy class.
Definition: CUDAUtils.h:214
bool IsCUDA() const
Returns true iff device type is CUDA.
Definition: Device.h:49
bool IsSYCL() const
Returns true iff device type is SYCL GPU.
Definition: Device.h:52
std::string ToString() const
Definition: Dtype.h:65
Tensor Contiguous() const
Definition: Tensor.cpp:772
Dtype GetDtype() const
Definition: Tensor.h:1164
Device GetDevice() const override
Definition: Tensor.cpp:1435
static Tensor Empty(const SizeVector &shape, Dtype dtype, const Device &device=Device("CPU:0"))
Create a tensor with uninitialized values.
Definition: Tensor.cpp:400
SizeVector GetShape() const
Definition: Tensor.h:1127
Tensor To(Dtype dtype, bool copy=false) const
Definition: Tensor.cpp:739
#define LogError(...)
Definition: Logging.h:60
#define LogDebug(...)
Definition: Logging.h:90
void MatmulCUDA(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, Dtype dtype, const Device &device)
Definition: MatmulCUDA.cpp:17
void MatmulSYCL(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, Dtype dtype, const Device &device)
Definition: MatmulSYCL.cpp:19
void Matmul(const Tensor &A, const Tensor &B, Tensor &output)
Computes matrix multiplication C = AB.
Definition: Matmul.cpp:17
void MatmulCPU(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, Dtype dtype)
Definition: MatmulCPU.cpp:14
const Dtype Float64
Definition: Dtype.cpp:43
const Dtype Float32
Definition: Dtype.cpp:42
Generic file read and write utility for python interface.