ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
AddMM.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
9 
10 #include <unordered_map>
11 
13 
14 namespace cloudViewer {
15 namespace core {
16 
17 void AddMM(const Tensor& A,
18  const Tensor& B,
19  Tensor& output,
20  double alpha,
21  double beta) {
24  AssertTensorDevice(output, A.GetDevice());
25  AssertTensorDtype(output, A.GetDtype());
26 
27  const Device device = A.GetDevice();
28  const Dtype dtype = A.GetDtype();
29 
30  if (dtype != core::Float32 && dtype != core::Float64) {
31  utility::LogError("AddMM is not implemented for {}.", dtype.ToString());
32  }
33 
34  // Check shapes
35  SizeVector A_shape = A.GetShape();
36  SizeVector B_shape = B.GetShape();
37  SizeVector output_shape = output.GetShape();
38 
39  if (A_shape.size() != 2) {
40  utility::LogError("Tensor A must be 2D, but got {}D.", A_shape.size());
41  }
42  if (B_shape.size() != 1 && B_shape.size() != 2) {
44  "Tensor B must be 1D (vector) or 2D (matrix), but got {}D.",
45  B_shape.size());
46  }
47  if (A_shape[1] != B_shape[0]) {
48  utility::LogError("Tensor A columns {} mismatch with Tensor B rows {}.",
49  A_shape[1], B_shape[0]);
50  }
51  if (output_shape[0] != A_shape[0] &&
52  output_shape[1] != B_shape[B_shape.size() - 1]) {
54  "Tensor output must match A rows {} and B columns {}.",
55  A_shape[0], B_shape[B_shape.size() - 1]);
56  }
57 
58  // Check the memory layout of tensors.
59  Tensor A_contiguous, B_contiguous;
60  bool transA = false;
61  bool transB = false;
62  if (A.IsContiguous() || A.T().IsContiguous()) {
63  transA = A.T().IsContiguous();
64  A_contiguous = A;
65  } else {
66  A_contiguous = A.Contiguous();
67  }
68 
69  if (B.IsContiguous() || B.T().IsContiguous()) {
70  transB = B.T().IsContiguous();
71  B_contiguous = B;
72  } else {
73  B_contiguous = B.Contiguous();
74  }
75 
76  // Dispatch to backends
77  int64_t m = output.GetShape(0);
78  int64_t n = output.GetShape(1);
79  int64_t k = A_contiguous.GetShape(1);
80 
81  int lda = A_contiguous.GetStride(transA ? 1 : 0);
82  int ldb = B_contiguous.GetStride(transB ? 1 : 0);
83  int ldc = output.GetStride(0);
84 
85  if (m == 0 || k == 0 || n == 0) {
87  "Tensor shapes should not contain dimensions with zero.");
88  }
89 
90  void* A_data = A_contiguous.To(dtype).GetDataPtr();
91  void* B_data = B_contiguous.To(dtype).GetDataPtr();
92  void* C_data = output.GetDataPtr();
93 
94  if (device.IsCUDA()) {
95 #ifdef BUILD_CUDA_MODULE
96  CUDAScopedDevice scoped_device(device);
97  AddMMCUDA(B_data, A_data, C_data, n, k, m, alpha, beta, transB, transA,
98  ldb, lda, ldc, dtype, device);
99 #else
100  utility::LogError("Unimplemented device.");
101 #endif
102  } else if (device.IsSYCL()) {
103 #ifdef BUILD_SYCL_MODULE
104  AddMMSYCL(B_data, A_data, C_data, n, k, m, alpha, beta, transB, transA,
105  ldb, lda, ldc, dtype, device);
106 #else
107  utility::LogError("Unimplemented device.");
108 #endif
109  } else {
110  AddMMCPU(B_data, A_data, C_data, n, k, m, alpha, beta, transB, transA,
111  ldb, lda, ldc, dtype);
112  }
113 };
114 
115 } // namespace core
116 } // namespace cloudViewer
Common CUDA utilities.
#define AssertTensorDevice(tensor,...)
Definition: TensorCheck.h:45
#define AssertTensorDtype(tensor,...)
Definition: TensorCheck.h:21
When CUDA is not enabled, this is a dummy class.
Definition: CUDAUtils.h:214
bool IsCUDA() const
Returns true iff device type is CUDA.
Definition: Device.h:49
bool IsSYCL() const
Returns true iff device type is SYCL GPU.
Definition: Device.h:52
std::string ToString() const
Definition: Dtype.h:65
Tensor Contiguous() const
Definition: Tensor.cpp:772
bool IsContiguous() const
Definition: Tensor.h:1036
Dtype GetDtype() const
Definition: Tensor.h:1164
int64_t GetStride(int64_t dim) const
Definition: Tensor.h:1139
Device GetDevice() const override
Definition: Tensor.cpp:1435
SizeVector GetShape() const
Definition: Tensor.h:1127
Tensor T() const
Expects input to be <= 2-D Tensor by swapping dimension 0 and 1.
Definition: Tensor.cpp:1079
Tensor To(Dtype dtype, bool copy=false) const
Definition: Tensor.cpp:739
#define LogError(...)
Definition: Logging.h:60
void AddMMCPU(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, double alpha, double beta, bool gemmTrA, bool gemmTrB, int lda, int ldb, int ldc, Dtype dtype)
Definition: AddMMCPU.cpp:17
void AddMM(const Tensor &A, const Tensor &B, Tensor &output, double alpha, double beta)
Definition: AddMM.cpp:17
const Dtype Float64
Definition: Dtype.cpp:43
void AddMMCUDA(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, double alpha, double beta, bool gemmTrA, bool gemmTrB, int lda, int ldb, int ldc, Dtype dtype, const Device &device)
Definition: AddMMCUDA.cpp:17
void AddMMSYCL(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, double alpha, double beta, bool gemmTrA, bool gemmTrB, int lda, int ldb, int ldc, Dtype dtype, const Device &device)
Definition: AddMMSYCL.cpp:20
const Dtype Float32
Definition: Dtype.cpp:42
Generic file read and write utility for python interface.