ACloudViewer 3.9.4
A Modern Library for 3D Data Processing
TriCUDA.cu
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include "core/Dispatch.h"
9 #include "core/ParallelFor.h"
10 #include "core/Tensor.h"
11 #include "core/linalg/TriImpl.h"
12 
13 namespace cloudViewer {
14 namespace core {
15 
16 void TriuCUDA(const Tensor &A, Tensor &output, const int diagonal) {
17  DISPATCH_DTYPE_TO_TEMPLATE(A.GetDtype(), [&]() {
18  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
19  scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
20  int cols = A.GetShape()[1];
21  int n = A.GetShape()[0] * cols;
22 
23  core::ParallelFor(
24  A.GetDevice(), n, [=] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
25  const int64_t idx = workload_idx / cols;
26  const int64_t idy = workload_idx % cols;
27  if (idy - idx >= diagonal) {
28  output_ptr[workload_idx] = A_ptr[idx * cols + idy];
29  }
30  });
31  });
32 }
33 
34 void TrilCUDA(const Tensor &A, Tensor &output, const int diagonal) {
35  DISPATCH_DTYPE_TO_TEMPLATE(A.GetDtype(), [&]() {
36  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
37  scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
38  int cols = A.GetShape()[1];
39  int n = A.GetShape()[0] * cols;
40 
41  core::ParallelFor(
42  A.GetDevice(), n, [=] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
43  const int64_t idx = workload_idx / cols;
44  const int64_t idy = workload_idx % cols;
45  if (idy - idx <= diagonal) {
46  output_ptr[workload_idx] = A_ptr[idx * cols + idy];
47  }
48  });
49  });
50 }
51 
52 void TriulCUDA(const Tensor &A,
53  Tensor &upper,
54  Tensor &lower,
55  const int diagonal) {
56  DISPATCH_DTYPE_TO_TEMPLATE(A.GetDtype(), [&]() {
57  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
58  scalar_t *lower_ptr = static_cast<scalar_t *>(lower.GetDataPtr());
59  scalar_t *upper_ptr = static_cast<scalar_t *>(upper.GetDataPtr());
60  int cols = A.GetShape()[1];
61  int n = A.GetShape()[0] * cols;
62 
63  core::ParallelFor(
64  A.GetDevice(), n, [=] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
65  const int64_t idx = workload_idx / cols;
66  const int64_t idy = workload_idx % cols;
67  if (idy - idx < diagonal) {
68  lower_ptr[workload_idx] = A_ptr[idx * cols + idy];
69  } else if (idy - idx > diagonal) {
70  upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
71  } else {
72  lower_ptr[workload_idx] = 1;
73  upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
74  }
75  });
76  });
77 }
78 
79 } // namespace core
80 } // namespace cloudViewer