ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
LeastSquaresCUDA.cpp
Go to the documentation of this file.
// ----------------------------------------------------------------------------
// -                        CloudViewer: www.cloudViewer.org                  -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------
7 
14 
15 namespace cloudViewer {
16 namespace core {
17 
18 // cusolverDn<t1><t2>gels() is not supported until CUDA 11.0.
19 // We have to implement for earlier versions via
20 // Step 1: A = Q*R by geqrf.
21 // Step 2: B : = Q ^ T* B by ormqr.
22 // Step 3: solve R* X = B by trsm.
23 // Ref: https://docs.nvidia.com/cuda/cusolver/index.html#ormqr-example1
24 void LeastSquaresCUDA(void* A_data,
25  void* B_data,
26  int64_t m,
27  int64_t n,
28  int64_t k,
29  Dtype dtype,
30  const Device& device) {
31  cusolverDnHandle_t cusolver_handle =
32  CuSolverContext::GetInstance().GetHandle(device);
33  cublasHandle_t cublas_handle =
34  CuBLASContext::GetInstance().GetHandle(device);
35 
37  int len_geqrf, len_ormqr, len;
38  Blob dinfo(sizeof(int), device);
39 
40  OPEN3D_CUSOLVER_CHECK(geqrf_cuda_buffersize<scalar_t>(
41  cusolver_handle, m, n, m, &len_geqrf),
42  "geqrf_buffersize failed in LeastSquaresCUDA");
43  OPEN3D_CUSOLVER_CHECK(ormqr_cuda_buffersize<scalar_t>(
44  cusolver_handle, CUBLAS_SIDE_LEFT,
45  CUBLAS_OP_T, m, k, n, m, m, &len_ormqr),
46  "ormqr_buffersize failed in LeastSquaresCUDA");
47  len = std::max(len_geqrf, len_ormqr);
48 
49  Blob workspace(len * sizeof(scalar_t), device);
50  Blob tau(n * sizeof(scalar_t), device);
51 
52  // Step 1: A = QR
53  OPEN3D_CUSOLVER_CHECK_WITH_DINFO(
54  geqrf_cuda<scalar_t>(
55  cusolver_handle, m, n, static_cast<scalar_t*>(A_data),
56  m, static_cast<scalar_t*>(tau.GetDataPtr()),
57  static_cast<scalar_t*>(workspace.GetDataPtr()), len,
58  static_cast<int*>(dinfo.GetDataPtr())),
59  "geqrf failed in LeastSquaresCUDA",
60  static_cast<int*>(dinfo.GetDataPtr()), device);
61 
62  // Step 2: B' = Q^T*B
63  OPEN3D_CUSOLVER_CHECK_WITH_DINFO(
64  ormqr_cuda<scalar_t>(
65  cusolver_handle, CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, k, n,
66  static_cast<scalar_t*>(A_data), m,
67  static_cast<scalar_t*>(tau.GetDataPtr()),
68  static_cast<scalar_t*>(B_data), m,
69  static_cast<scalar_t*>(workspace.GetDataPtr()), len,
70  static_cast<int*>(dinfo.GetDataPtr())),
71  "ormqr failed in LeastSquaresCUDA",
72  static_cast<int*>(dinfo.GetDataPtr()), device);
73 
74  // Step 3: Solve Rx = B'
75  scalar_t alpha = 1.0f;
76  OPEN3D_CUBLAS_CHECK(
77  trsm_cuda<scalar_t>(cublas_handle, CUBLAS_SIDE_LEFT,
78  CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
79  CUBLAS_DIAG_NON_UNIT, n, k, &alpha,
80  static_cast<scalar_t*>(A_data), m,
81  static_cast<scalar_t*>(B_data), m),
82  "trsm failed in LeastSquaresCUDA");
83  });
84 }
85 
86 } // namespace core
87 } // namespace cloudViewer
Common CUDA utilities.
#define DISPATCH_LINALG_DTYPE_TO_TEMPLATE(DTYPE,...)
Definition: LinalgUtils.h:23
void * GetDataPtr()
Definition: Blob.h:75
int max(int a, int b)
Definition: cutil_math.h:48
ccGuiPythonInstance * GetInstance() noexcept
Definition: Runtime.cpp:72
void LeastSquaresCUDA(void *A_data, void *B_data, int64_t m, int64_t n, int64_t k, Dtype dtype, const Device &device)
Generic file read and write utility for python interface.