ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
SolveCUDA.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
14 
15 namespace cloudViewer {
16 namespace core {
17 
18 // cuSolver's gesv will crash when A is a singular matrix.
19 // We implement an LU decomposition-based solver (similar to Inverse) instead.
//
// SolveCUDA runs two cuSOLVER steps on the handle obtained for `device`:
// getrf on A, followed by getrs using A, the pivots and B.
//
// Parameters:
//   A_data    - cast to scalar_t* and handed to getrf and then getrs, each
//               time with leading dimension n. Presumably overwritten in
//               place with its LU factors -- confirm against cuSOLVER docs.
//   B_data    - cast to scalar_t* and handed to getrs with leading
//               dimension n. Presumably overwritten with the solution --
//               confirm against cuSOLVER docs.
//   ipiv_data - cast to int*; passed to getrf and then to getrs.
//               Presumably the pivot indices produced by getrf -- confirm
//               required buffer length with the caller.
//   n         - size passed as both dimensions of A and as the leading
//               dimension of A and B.
//   k         - passed to getrs after n; presumably the number of
//               right-hand-side columns in B -- TODO confirm.
//   dtype     - not referenced in the visible lines; see the NOTE(review)
//               inside the body about the dropped dispatch line.
//   device    - selects the cuSOLVER handle, is where the scratch Blobs are
//               created, and is forwarded to the DINFO check macro.
//
// NOTE(review): n and k are int64_t while cuSOLVER's dense API takes int
// sizes; whether the getrf_cuda/getrs_cuda wrappers narrow them is not
// visible here -- verify for large inputs.
20 void SolveCUDA(void* A_data,
21  void* B_data,
22  void* ipiv_data,
23  int64_t n,
24  int64_t k,
25  Dtype dtype,
26  const Device& device) {
// Fetch the cuSOLVER handle that the context singleton keeps for `device`.
27  cusolverDnHandle_t handle =
28  CuSolverContext::GetInstance().GetHandle(device);
29 
// NOTE(review): this listing jumps from original line 29 to 31. `scalar_t`
// is used below without a visible declaration and the body ends with `});`,
// so an opener such as `DISPATCH_LINALG_DTYPE_TO_TEMPLATE(dtype, [&]() {`
// appears to have been dropped here -- confirm against the real source.
// `len` is filled in by getrf_cuda_buffersize through the &len argument.
31  int len;
// One-int scratch Blob created on `device`; its pointer is passed as the
// info output of getrf/getrs and again to the DINFO check macro.
32  Blob dinfo(sizeof(int), device);
33 
// Query the workspace size getrf needs for the (n, n) problem.
34  OPEN3D_CUSOLVER_CHECK(
35  getrf_cuda_buffersize<scalar_t>(handle, n, n, n, &len),
36  "getrf_buffersize failed in SolveCUDA");
// `len` is treated as an element count, hence the sizeof(scalar_t) scaling.
37  Blob workspace(len * sizeof(scalar_t), device);
38 
// Step 1: factorize A, writing pivots to ipiv_data and status to dinfo.
// The macro receives the call result, a message, the dinfo pointer and the
// device; presumably it validates both the status and the info value.
39  OPEN3D_CUSOLVER_CHECK_WITH_DINFO(
40  getrf_cuda<scalar_t>(
41  handle, n, n, static_cast<scalar_t*>(A_data), n,
42  static_cast<scalar_t*>(workspace.GetDataPtr()),
43  static_cast<int*>(ipiv_data),
44  static_cast<int*>(dinfo.GetDataPtr())),
45  "getrf failed in SolveCUDA",
46  static_cast<int*>(dinfo.GetDataPtr()), device);
47 
// Step 2: solve with the factorized A and the pivots from step 1.
// CUBLAS_OP_N requests the non-transposed operation.
48  OPEN3D_CUSOLVER_CHECK_WITH_DINFO(
49  getrs_cuda<scalar_t>(handle, CUBLAS_OP_N, n, k,
50  static_cast<scalar_t*>(A_data), n,
51  static_cast<int*>(ipiv_data),
52  static_cast<scalar_t*>(B_data), n,
53  static_cast<int*>(dinfo.GetDataPtr())),
54  "getrs failed in SolveCUDA",
55  static_cast<int*>(dinfo.GetDataPtr()), device);
56  });
57 }
58 
59 } // namespace core
60 } // namespace cloudViewer
Common CUDA utilities.
#define DISPATCH_LINALG_DTYPE_TO_TEMPLATE(DTYPE,...)
Definition: LinalgUtils.h:23
void * GetDataPtr()
Definition: Blob.h:75
ccGuiPythonInstance * GetInstance() noexcept
Definition: Runtime.cpp:72
void SolveCUDA(void *A_data, void *B_data, void *ipiv_data, int64_t n, int64_t k, Dtype dtype, const Device &device)
Definition: SolveCUDA.cpp:20
Generic file read and write utility for python interface.