ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
TriCPU.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
13 
14 namespace cloudViewer {
15 namespace core {
16 
17 void TriuCPU(const Tensor &A, Tensor &output, const int diagonal) {
19  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
20  scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
21  int cols = A.GetShape()[1];
22  int n = A.GetShape()[0] * cols;
23 
24  ParallelFor(A.GetDevice(), n,
25  [&] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
26  const int64_t idx = workload_idx / cols;
27  const int64_t idy = workload_idx % cols;
28  if (idy - idx >= diagonal) {
29  output_ptr[workload_idx] = A_ptr[idx * cols + idy];
30  }
31  });
32  });
33 }
34 
35 void TrilCPU(const Tensor &A, Tensor &output, const int diagonal) {
37  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
38  scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
39  int cols = A.GetShape()[1];
40  int n = A.GetShape()[0] * cols;
41 
42  ParallelFor(A.GetDevice(), n,
43  [&] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
44  const int64_t idx = workload_idx / cols;
45  const int64_t idy = workload_idx % cols;
46  if (idy - idx <= diagonal) {
47  output_ptr[workload_idx] = A_ptr[idx * cols + idy];
48  }
49  });
50  });
51 }
52 
53 void TriulCPU(const Tensor &A,
54  Tensor &upper,
55  Tensor &lower,
56  const int diagonal) {
58  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
59  scalar_t *upper_ptr = static_cast<scalar_t *>(upper.GetDataPtr());
60  scalar_t *lower_ptr = static_cast<scalar_t *>(lower.GetDataPtr());
61  int cols = A.GetShape()[1];
62  int n = A.GetShape()[0] * cols;
63 
64  ParallelFor(A.GetDevice(), n,
65  [&] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
66  const int64_t idx = workload_idx / cols;
67  const int64_t idy = workload_idx % cols;
68  if (idy - idx < diagonal) {
69  lower_ptr[workload_idx] = A_ptr[idx * cols + idy];
70  } else if (idy - idx > diagonal) {
71  upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
72  } else {
73  lower_ptr[workload_idx] = 1;
74  upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
75  }
76  });
77  });
78 }
79 
80 } // namespace core
81 } // namespace cloudViewer
#define DISPATCH_DTYPE_TO_TEMPLATE(DTYPE,...)
Definition: Dispatch.h:31
Dtype GetDtype() const
Definition: Tensor.h:1164
void TrilCPU(const Tensor &A, Tensor &output, const int diagonal)
Definition: TriCPU.cpp:35
void TriulCPU(const Tensor &A, Tensor &upper, Tensor &lower, const int diagonal)
Definition: TriCPU.cpp:53
void TriuCPU(const Tensor &A, Tensor &output, const int diagonal)
Definition: TriCPU.cpp:17
Generic file read and write utility for python interface.