ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
BlasWrapper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include <Logging.h>
11 
14 
15 namespace cloudViewer {
16 namespace core {
17 
18 template <typename scalar_t>
19 inline void gemm_cpu(CBLAS_LAYOUT layout,
20  CBLAS_TRANSPOSE trans_A,
21  CBLAS_TRANSPOSE trans_B,
25  scalar_t alpha,
26  const scalar_t *A_data,
28  const scalar_t *B_data,
30  scalar_t beta,
31  scalar_t *C_data,
33  utility::LogError("Unsupported data type.");
34 }
35 
36 template <>
37 inline void gemm_cpu<float>(CBLAS_LAYOUT layout,
38  CBLAS_TRANSPOSE trans_A,
39  CBLAS_TRANSPOSE trans_B,
43  float alpha,
44  const float *A_data,
46  const float *B_data,
48  float beta,
49  float *C_data,
51  cblas_sgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
52  ldb, beta, C_data, ldc);
53 }
54 
55 template <>
56 inline void gemm_cpu<double>(CBLAS_LAYOUT layout,
57  CBLAS_TRANSPOSE trans_A,
58  CBLAS_TRANSPOSE trans_B,
62  double alpha,
63  const double *A_data,
65  const double *B_data,
67  double beta,
68  double *C_data,
70  cblas_dgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
71  ldb, beta, C_data, ldc);
72 }
73 
74 #ifdef BUILD_CUDA_MODULE
75 template <typename scalar_t>
76 inline cublasStatus_t gemm_cuda(cublasHandle_t handle,
77  cublasOperation_t transa,
78  cublasOperation_t transb,
79  int m,
80  int n,
81  int k,
82  const scalar_t *alpha,
83  const scalar_t *A_data,
84  int lda,
85  const scalar_t *B_data,
86  int ldb,
87  const scalar_t *beta,
88  scalar_t *C_data,
89  int ldc) {
90  utility::LogError("Unsupported data type.");
91  return CUBLAS_STATUS_NOT_SUPPORTED;
92 }
93 
94 template <typename scalar_t>
95 inline cublasStatus_t trsm_cuda(cublasHandle_t handle,
96  cublasSideMode_t side,
97  cublasFillMode_t uplo,
98  cublasOperation_t trans,
99  cublasDiagType_t diag,
100  int m,
101  int n,
102  const scalar_t *alpha,
103  const scalar_t *A,
104  int lda,
105  scalar_t *B,
106  int ldb) {
107  utility::LogError("Unsupported data type.");
108  return CUBLAS_STATUS_NOT_SUPPORTED;
109 }
110 
111 template <>
112 inline cublasStatus_t gemm_cuda<float>(cublasHandle_t handle,
113  cublasOperation_t transa,
114  cublasOperation_t transb,
115  int m,
116  int n,
117  int k,
118  const float *alpha,
119  const float *A_data,
120  int lda,
121  const float *B_data,
122  int ldb,
123  const float *beta,
124  float *C_data,
125  int ldc) {
126  return cublasSgemm(handle, transa,
127  transb, // A, B transpose flag
128  m, n, k, // dimensions
129  alpha, static_cast<const float *>(A_data), lda,
130  static_cast<const float *>(B_data),
131  ldb, // input and their leading dims
132  beta, static_cast<float *>(C_data), ldc);
133 }
134 
135 template <>
136 inline cublasStatus_t gemm_cuda<double>(cublasHandle_t handle,
137  cublasOperation_t transa,
138  cublasOperation_t transb,
139  int m,
140  int n,
141  int k,
142  const double *alpha,
143  const double *A_data,
144  int lda,
145  const double *B_data,
146  int ldb,
147  const double *beta,
148  double *C_data,
149  int ldc) {
150  return cublasDgemm(handle, transa,
151  transb, // A, B transpose flag
152  m, n, k, // dimensions
153  alpha, static_cast<const double *>(A_data), lda,
154  static_cast<const double *>(B_data),
155  ldb, // input and their leading dims
156  beta, static_cast<double *>(C_data), ldc);
157 }
158 
159 template <>
160 inline cublasStatus_t trsm_cuda<float>(cublasHandle_t handle,
161  cublasSideMode_t side,
162  cublasFillMode_t uplo,
163  cublasOperation_t trans,
164  cublasDiagType_t diag,
165  int m,
166  int n,
167  const float *alpha,
168  const float *A,
169  int lda,
170  float *B,
171  int ldb) {
172  return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
173  ldb);
174 }
175 
176 template <>
177 inline cublasStatus_t trsm_cuda<double>(cublasHandle_t handle,
178  cublasSideMode_t side,
179  cublasFillMode_t uplo,
180  cublasOperation_t trans,
181  cublasDiagType_t diag,
182  int m,
183  int n,
184  const double *alpha,
185  const double *A,
186  int lda,
187  double *B,
188  int ldb) {
189  return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
190  ldb);
191 }
192 #endif
193 
194 } // namespace core
195 } // namespace cloudViewer
#define CLOUDVIEWER_CPU_LINALG_INT
#define LogError(...)
Definition: Logging.h:60
void gemm_cpu(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, CLOUDVIEWER_CPU_LINALG_INT m, CLOUDVIEWER_CPU_LINALG_INT n, CLOUDVIEWER_CPU_LINALG_INT k, scalar_t alpha, const scalar_t *A_data, CLOUDVIEWER_CPU_LINALG_INT lda, const scalar_t *B_data, CLOUDVIEWER_CPU_LINALG_INT ldb, scalar_t beta, scalar_t *C_data, CLOUDVIEWER_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:19
void gemm_cpu< float >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, CLOUDVIEWER_CPU_LINALG_INT m, CLOUDVIEWER_CPU_LINALG_INT n, CLOUDVIEWER_CPU_LINALG_INT k, float alpha, const float *A_data, CLOUDVIEWER_CPU_LINALG_INT lda, const float *B_data, CLOUDVIEWER_CPU_LINALG_INT ldb, float beta, float *C_data, CLOUDVIEWER_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:37
void gemm_cpu< double >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, CLOUDVIEWER_CPU_LINALG_INT m, CLOUDVIEWER_CPU_LINALG_INT n, CLOUDVIEWER_CPU_LINALG_INT k, double alpha, const double *A_data, CLOUDVIEWER_CPU_LINALG_INT lda, const double *B_data, CLOUDVIEWER_CPU_LINALG_INT ldb, double beta, double *C_data, CLOUDVIEWER_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:56
Generic file read and write utility for python interface.