10 #include <unordered_map>
39 if (A_shape.
size() != 2) {
42 if (B_shape.
size() != 1 && B_shape.
size() != 2) {
44 "Tensor B must be 1D (vector) or 2D (matrix), but got {}D.",
47 if (A_shape[1] != B_shape[0]) {
49 A_shape[1], B_shape[0]);
51 if (output_shape[0] != A_shape[0] &&
52 output_shape[1] != B_shape[B_shape.
size() - 1]) {
54 "Tensor output must match A rows {} and B columns {}.",
55 A_shape[0], B_shape[B_shape.
size() - 1]);
59 Tensor A_contiguous, B_contiguous;
79 int64_t k = A_contiguous.
GetShape(1);
81 int lda = A_contiguous.
GetStride(transA ? 1 : 0);
82 int ldb = B_contiguous.
GetStride(transB ? 1 : 0);
85 if (m == 0 || k == 0 || n == 0) {
87 "Tensor shapes should not contain dimensions with zero.");
95 #ifdef BUILD_CUDA_MODULE
97 AddMMCUDA(B_data, A_data, C_data, n, k, m, alpha, beta, transB, transA,
98 ldb, lda, ldc, dtype, device);
102 }
else if (device.
IsSYCL()) {
103 #ifdef BUILD_SYCL_MODULE
104 AddMMSYCL(B_data, A_data, C_data, n, k, m, alpha, beta, transB, transA,
105 ldb, lda, ldc, dtype, device);
110 AddMMCPU(B_data, A_data, C_data, n, k, m, alpha, beta, transB, transA,
111 ldb, lda, ldc, dtype);
#define AssertTensorDevice(tensor,...)
#define AssertTensorDtype(tensor,...)
When CUDA is not enabled, this is a dummy class.
bool IsCUDA() const
Returns true iff device type is CUDA.
bool IsSYCL() const
Returns true iff device type is SYCL GPU.
std::string ToString() const
Tensor Contiguous() const
bool IsContiguous() const
int64_t GetStride(int64_t dim) const
Device GetDevice() const override
SizeVector GetShape() const
Tensor T() const
Expects input to be <= 2-D Tensor by swapping dimension 0 and 1.
Tensor To(Dtype dtype, bool copy=false) const
void AddMMCPU(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, double alpha, double beta, bool gemmTrA, bool gemmTrB, int lda, int ldb, int ldc, Dtype dtype)
void AddMM(const Tensor &A, const Tensor &B, Tensor &output, double alpha, double beta)
void AddMMCUDA(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, double alpha, double beta, bool gemmTrA, bool gemmTrB, int lda, int ldb, int ldc, Dtype dtype, const Device &device)
void AddMMSYCL(void *A_data, void *B_data, void *C_data, int64_t m, int64_t k, int64_t n, double alpha, double beta, bool gemmTrA, bool gemmTrB, int lda, int ldb, int ldc, Dtype dtype, const Device &device)
Generic file read and write utility for python interface.