ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
TriSYCL.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
12 
13 namespace cloudViewer {
14 namespace core {
15 
16 void TriuSYCL(const Tensor &A, Tensor &output, const int diagonal) {
17  sycl::queue queue =
20  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
21  scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
22  auto rows = static_cast<size_t>(A.GetShape()[0]),
23  cols = static_cast<size_t>(A.GetShape()[1]);
24  queue.parallel_for({cols, rows}, [=](auto wid) {
25  const auto wid_1d = wid[1] * cols + wid[0];
26  if (static_cast<int>(wid[0]) - static_cast<int>(wid[1]) >=
27  diagonal) {
28  output_ptr[wid_1d] = A_ptr[wid_1d];
29  }
30  }).wait_and_throw();
31  });
32 }
33 
34 void TrilSYCL(const Tensor &A, Tensor &output, const int diagonal) {
35  sycl::queue queue =
36  sy::SYCLContext::GetInstance().GetDefaultQueue(A.GetDevice());
38  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
39  scalar_t *output_ptr = static_cast<scalar_t *>(output.GetDataPtr());
40  auto rows = static_cast<size_t>(A.GetShape()[0]),
41  cols = static_cast<size_t>(A.GetShape()[1]);
42  queue.parallel_for({cols, rows}, [=](auto wid) {
43  const auto wid_1d = wid[1] * cols + wid[0];
44  if (static_cast<int>(wid[0]) - static_cast<int>(wid[1]) <=
45  diagonal) {
46  output_ptr[wid_1d] = A_ptr[wid_1d];
47  }
48  }).wait_and_throw();
49  });
50 }
51 
52 void TriulSYCL(const Tensor &A,
53  Tensor &upper,
54  Tensor &lower,
55  const int diagonal) {
56  sycl::queue queue =
57  sy::SYCLContext::GetInstance().GetDefaultQueue(A.GetDevice());
59  const scalar_t *A_ptr = static_cast<const scalar_t *>(A.GetDataPtr());
60  scalar_t *upper_ptr = static_cast<scalar_t *>(upper.GetDataPtr());
61  scalar_t *lower_ptr = static_cast<scalar_t *>(lower.GetDataPtr());
62  auto rows = static_cast<size_t>(A.GetShape()[0]),
63  cols = static_cast<size_t>(A.GetShape()[1]);
64  queue.parallel_for({cols, rows}, [=](auto wid) {
65  const auto wid_1d = wid[1] * cols + wid[0];
66  if (static_cast<int>(wid[0]) - static_cast<int>(wid[1]) <
67  diagonal) {
68  lower_ptr[wid_1d] = A_ptr[wid_1d];
69  } else if (static_cast<int>(wid[0]) -
70  static_cast<int>(wid[1]) >
71  diagonal) {
72  upper_ptr[wid_1d] = A_ptr[wid_1d];
73  } else {
74  lower_ptr[wid_1d] = 1;
75  upper_ptr[wid_1d] = A_ptr[wid_1d];
76  }
77  }).wait_and_throw();
78  });
79 }
80 
81 } // namespace core
82 } // namespace cloudViewer
#define DISPATCH_DTYPE_TO_TEMPLATE(DTYPE,...)
Definition: Dispatch.h:31
SYCL queue manager.
Dtype GetDtype() const
Definition: Tensor.h:1164
Device GetDevice() const override
Definition: Tensor.cpp:1435
static SYCLContext & GetInstance()
Get singleton instance.
Definition: SYCLContext.cpp:25
sycl::queue GetDefaultQueue(const Device &device)
Get the default SYCL queue given a CloudViewer device.
Definition: SYCLContext.cpp:43
ccGuiPythonInstance * GetInstance() noexcept
Definition: Runtime.cpp:72
void TrilSYCL(const Tensor &A, Tensor &output, const int diagonal)
Definition: TriSYCL.cpp:34
void TriulSYCL(const Tensor &A, Tensor &upper, Tensor &lower, const int diagonal)
Definition: TriSYCL.cpp:52
void TriuSYCL(const Tensor &A, Tensor &output, const int diagonal)
Definition: TriSYCL.cpp:16
Generic file read and write utility for the Python interface.