19 const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
20 scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
21 int cols = A.GetShape()[1];
22 int n = A.GetShape()[0] * cols;
24 ParallelFor(A.GetDevice(), n,
25 [&] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
26 const int64_t idx = workload_idx / cols;
27 const int64_t idy = workload_idx % cols;
28 if (idy - idx >= diagonal) {
29 output_ptr[workload_idx] = A_ptr[idx * cols + idy];
37 const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
38 scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
39 int cols = A.GetShape()[1];
40 int n = A.GetShape()[0] * cols;
42 ParallelFor(A.GetDevice(), n,
43 [&] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
44 const int64_t idx = workload_idx / cols;
45 const int64_t idy = workload_idx % cols;
46 if (idy - idx <= diagonal) {
47 output_ptr[workload_idx] = A_ptr[idx * cols + idy];
58 const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
59 scalar_t *upper_ptr = static_cast<scalar_t *>(upper.GetDataPtr());
60 scalar_t *lower_ptr = static_cast<scalar_t *>(lower.GetDataPtr());
61 int cols = A.GetShape()[1];
62 int n = A.GetShape()[0] * cols;
64 ParallelFor(A.GetDevice(), n,
65 [&] CLOUDVIEWER_DEVICE(int64_t workload_idx) {
66 const int64_t idx = workload_idx / cols;
67 const int64_t idy = workload_idx % cols;
68 if (idy - idx < diagonal) {
69 lower_ptr[workload_idx] = A_ptr[idx * cols + idy];
70 } else if (idy - idx > diagonal) {
71 upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
73 lower_ptr[workload_idx] = 1;
74 upper_ptr[workload_idx] = A_ptr[idx * cols + idy];
// Dtype-dispatch helper. Presumably binds `scalar_t` to the C++ type matching
// DTYPE and expands the variadic body; the macro body is not visible here —
// TODO confirm against its definition.
#define DISPATCH_DTYPE_TO_TEMPLATE(DTYPE,...)
// TrilCPU: presumably copies the elements of 2-D tensor `A` with
// (column - row) <= diagonal into `output` (lower-triangular part).
// The visible kernels index A as row-major: A_ptr[row * cols + col].
// NOTE(review): the kernels compute the element count as `int`, which can
// overflow for tensors with more than INT_MAX elements.
void TrilCPU(const Tensor &A, Tensor &output, const int diagonal)
// TriulCPU: splits `A` into `lower` and `upper`. Elements with
// (column - row) < diagonal go to `lower`, and elements with
// (column - row) > diagonal go to `upper`. On the (column - row) == diagonal
// line, `lower` receives 1 and `upper` receives A's value.
void TriulCPU(const Tensor &A, Tensor &upper, Tensor &lower, const int diagonal)
// TriuCPU: presumably copies the elements of `A` with
// (column - row) >= diagonal into `output` (upper-triangular part).
void TriuCPU(const Tensor &A, Tensor &output, const int diagonal)
Generic file read and write utility for the Python interface.