1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
8 #include "core/Dispatch.h"
9 #include "core/ParallelFor.h"
10 #include "core/Tensor.h"
11 #include "core/linalg/TriImpl.h"
13 namespace cloudViewer {
16 void TriuCUDA(const Tensor &A, Tensor &output, const int diagonal) {
17 DISPATCH_DTYPE_TO_TEMPLATE(A.GetDtype(), [&]() {
18 const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
19 scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
20 int cols = A.GetShape()[1];
21 int n = A.GetShape()[0] * cols;
24 A.GetDevice(), n, [=] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
25 const int64_t idx = workload_idx / cols;
26 const int64_t idy = workload_idx % cols;
27 if (idy - idx >= diagonal) {
28 output_ptr[workload_idx] = A_ptr[idx * cols + idy];
34 void TrilCUDA(const Tensor &A, Tensor &output, const int diagonal) {
35 DISPATCH_DTYPE_TO_TEMPLATE(A.GetDtype(), [&]() {
36 const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
37 scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
38 int cols = A.GetShape()[1];
39 int n = A.GetShape()[0] * cols;
42 A.GetDevice(), n, [=] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
43 const int64_t idx = workload_idx / cols;
44 const int64_t idy = workload_idx % cols;
45 if (idy - idx <= diagonal) {
46 output_ptr[workload_idx] = A_ptr[idx * cols + idy];
52 void TriulCUDA(const Tensor &A,
56 DISPATCH_DTYPE_TO_TEMPLATE(A.GetDtype(), [&]() {
57 const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
58 scalar_t *lower_ptr = static_cast<scalar_t *>(lower.GetDataPtr());
59 scalar_t *upper_ptr = static_cast<scalar_t *>(upper.GetDataPtr());
60 int cols = A.GetShape()[1];
61 int n = A.GetShape()[0] * cols;
64 A.GetDevice(), n, [=] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
65 const int64_t idx = workload_idx / cols;
66 const int64_t idy = workload_idx % cols;
67 if (idy - idx < diagonal) {
68 lower_ptr[workload_idx] = A_ptr[idx * cols + idy];
69 } else if (idy - idx > diagonal) {
70 upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
72 lower_ptr[workload_idx] = 1;
73 upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
80 } // namespace cloudViewer