ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
BinaryEWSYCL.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include <Logging.h>
9 
11 #include "cloudViewer/core/Dtype.h"
18 
19 namespace cloudViewer {
20 namespace core {
21 namespace kernel {
22 
23 namespace {
24 
25 struct BinaryElementKernel {
26  void operator()(int64_t i) {}
27  BinaryElementKernel(Indexer indexer_) : indexer(indexer_) {}
28 
29 protected:
30  Indexer indexer;
31 };
32 
// Min, Max: kernels whose element operation is a callable `elem_fn(a, b)`.
//
// Expands to a functor template `<name>ElementKernel<src_t, dst_t>` (dst_t
// defaults to src_t). For workload index i it reads input 0 and input 1 as
// src_t, evaluates elem_fn on their values, and stores the result into the
// output element viewed as dst_t.
#define BINARY_ELEMENT_KERNEL(name, elem_fn)                      \
    template <typename src_t, typename dst_t = src_t>             \
    struct name##ElementKernel : public BinaryElementKernel {     \
        using BinaryElementKernel::BinaryElementKernel;           \
        void operator()(int64_t i) {                              \
            const src_t& a = *indexer.GetInputPtr<src_t>(0, i);   \
            const src_t& b = *indexer.GetInputPtr<src_t>(1, i);   \
            dst_t& out = *indexer.GetOutputPtr<dst_t>(i);         \
            out = elem_fn(a, b);                                  \
        }                                                         \
    }
45 
48 #undef BINARY_ELEMENT_KERNEL
49 
51 template <>
52 struct MaxElementKernel<bool, bool> : public BinaryElementKernel {
53  using BinaryElementKernel::BinaryElementKernel;
54  void operator()(int64_t i) {
55  const bool* lhs = indexer.GetInputPtr<bool>(0, i);
56  const bool* rhs = indexer.GetInputPtr<bool>(1, i);
57  bool* dst = indexer.GetOutputPtr<bool>(i);
58  *dst = *lhs || *rhs;
59  }
60 };
61 template <>
62 struct MinElementKernel<bool, bool> : public BinaryElementKernel {
63  using BinaryElementKernel::BinaryElementKernel;
64  void operator()(int64_t i) {
65  const bool* lhs = indexer.GetInputPtr<bool>(0, i);
66  const bool* rhs = indexer.GetInputPtr<bool>(1, i);
67  bool* dst = indexer.GetOutputPtr<bool>(i);
68  *dst = *lhs && *rhs;
69  }
70 };
71 
// Arithmetic and relational ops: kernels whose element operation is an infix
// operator.
//
// Expands to `<name>ElementKernel<src_t, dst_t>` (dst_t defaults to src_t)
// computing `out = a <elem_op> b` for the two src_t inputs at workload index i.
// The operator's result is implicitly converted to dst_t on assignment.
#define BINARY_ELEMENT_KERNEL(name, elem_op)                      \
    template <typename src_t, typename dst_t = src_t>             \
    struct name##ElementKernel : public BinaryElementKernel {     \
        using BinaryElementKernel::BinaryElementKernel;           \
        void operator()(int64_t i) {                              \
            const src_t& a = *indexer.GetInputPtr<src_t>(0, i);   \
            const src_t& b = *indexer.GetInputPtr<src_t>(1, i);   \
            dst_t& out = *indexer.GetOutputPtr<dst_t>(i);         \
            out = a elem_op b;                                    \
        }                                                         \
    }
84 
// Arithmetic kernels: each applies the given infix operator to the two inputs.
// NOTE(review): Div applies `/` directly; nothing at this level guards against
// integer division by zero.
85 BINARY_ELEMENT_KERNEL(Add, +);
86 BINARY_ELEMENT_KERNEL(Sub, -);
87 BINARY_ELEMENT_KERNEL(Mul, *);
88 BINARY_ELEMENT_KERNEL(Div, /);
// Relational kernels. With dst_t = bool they store a boolean result; with the
// default dst_t = src_t the comparison result is converted to the input type.
91 BINARY_ELEMENT_KERNEL(Geq, >=);
92 BINARY_ELEMENT_KERNEL(Leq, <=);
93 BINARY_ELEMENT_KERNEL(Eq, ==);
94 BINARY_ELEMENT_KERNEL(Neq, !=);
// Release the macro name so the logical-op variant can redefine it.
95 #undef BINARY_ELEMENT_KERNEL
96 
97 // Logical ops.
98 #define BINARY_ELEMENT_KERNEL(name, elem_op) \
99  template <typename src_t, typename dst_t = src_t> \
100  struct name##ElementKernel : public BinaryElementKernel { \
101  using BinaryElementKernel::BinaryElementKernel; \
102  void operator()(int64_t i) { \
103  const src_t* lhs = indexer.GetInputPtr<src_t>(0, i); \
104  const src_t* rhs = indexer.GetInputPtr<src_t>(1, i); \
105  dst_t* dst = indexer.GetOutputPtr<dst_t>(i); \
106  *dst = static_cast<bool>(*lhs) elem_op static_cast<bool>(*rhs); \
107  } \
108  }
109 BINARY_ELEMENT_KERNEL(LogicalAnd, &&);
110 BINARY_ELEMENT_KERNEL(LogicalOr, ||);
111 BINARY_ELEMENT_KERNEL(LogicalXor, !=);
112 #undef BINARY_ELEMENT_KERNEL
113 
114 } // namespace
115 
// Applies the binary element-wise operation `op_code` to `lhs` and `rhs` and
// writes the result into `dst`, launching kernels via ParallelForSYCL on the
// device of `lhs`.
//
// Dispatch paths:
//  - op_code in s_boolean_binary_ew_op_codes (judging by the switches below,
//    presumably the logical and comparison ops). `dst` may either have the
//    same dtype as `lhs` (in-place style, e.g. np.logical_and(a, b, out=a))
//    or be core::Bool; any other dst dtype reports an error.
//  - Maximum / Minimum: one Indexer with DtypePolicy::ALL_SAME, dispatched
//    over dtypes including Bool (Bool uses the OR / AND specializations).
//  - Otherwise: Add, Sub, Mul, Div with DtypePolicy::ALL_SAME, dispatched via
//    DISPATCH_DTYPE_TO_TEMPLATE (the non-_WITH_BOOL variant; presumably Bool
//    is not accepted here - confirm in Dispatch.h).
//
// @param lhs      First operand. Its dtype is the source dtype, and its
//                 device selects where kernels run.
// @param rhs      Second operand.
// @param dst      Output tensor, written element-wise.
// @param op_code  Operation to apply.
// NOTE(review): op codes not matched by any case reach `default: break;` and
// return without writing dst or reporting an error.
116 void BinaryEWSYCL(const Tensor& lhs,
117  const Tensor& rhs,
118  Tensor& dst,
119  BinaryEWOpCode op_code) {
120  Dtype src_dtype = lhs.GetDtype();
121  Dtype dst_dtype = dst.GetDtype();
122  Device device = lhs.GetDevice();
123 
124  if (s_boolean_binary_ew_op_codes.find(op_code) !=
126  if (dst_dtype == src_dtype) {
127  // Inplace boolean op's output type is the same as the
128  // input. e.g. np.logical_and(a, b, out=a), where a, b are
129  // floats.
130  Indexer indexer({lhs, rhs}, dst, DtypePolicy::ALL_SAME);
131  DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(src_dtype, [&]() {
132  switch (op_code) {
134  ParallelForSYCL<LogicalAndElementKernel<scalar_t>>(
135  device, indexer);
136  break;
138  ParallelForSYCL<LogicalOrElementKernel<scalar_t>>(
139  device, indexer);
140  break;
142  ParallelForSYCL<LogicalXorElementKernel<scalar_t>>(
143  device, indexer);
144  break;
145  case BinaryEWOpCode::Gt:
146  ParallelForSYCL<GtElementKernel<scalar_t>>(device,
147  indexer);
148  break;
149  case BinaryEWOpCode::Lt:
150  ParallelForSYCL<LtElementKernel<scalar_t>>(device,
151  indexer);
152  break;
153  case BinaryEWOpCode::Ge:
154  ParallelForSYCL<GeqElementKernel<scalar_t>>(device,
155  indexer);
156  break;
157  case BinaryEWOpCode::Le:
158  ParallelForSYCL<LeqElementKernel<scalar_t>>(device,
159  indexer);
160  break;
161  case BinaryEWOpCode::Eq:
162  ParallelForSYCL<EqElementKernel<scalar_t>>(device,
163  indexer);
164  break;
165  case BinaryEWOpCode::Ne:
166  ParallelForSYCL<NeqElementKernel<scalar_t>>(device,
167  indexer);
168  break;
169  default:
170  break;
171  }
172  });
173  } else if (dst_dtype == core::Bool) {
174  // By default, output is boolean type.
175  Indexer indexer({lhs, rhs}, dst,
177  DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(src_dtype, [&]() {
178  switch (op_code) {
181  LogicalAndElementKernel<scalar_t, bool>>(
182  device, indexer);
183  break;
185  ParallelForSYCL<LogicalOrElementKernel<scalar_t, bool>>(
186  device, indexer);
187  break;
190  LogicalXorElementKernel<scalar_t, bool>>(
191  device, indexer);
192  break;
193  case BinaryEWOpCode::Gt:
194  ParallelForSYCL<GtElementKernel<scalar_t, bool>>(
195  device, indexer);
196  break;
197  case BinaryEWOpCode::Lt:
198  ParallelForSYCL<LtElementKernel<scalar_t, bool>>(
199  device, indexer);
200  break;
201  case BinaryEWOpCode::Ge:
202  ParallelForSYCL<GeqElementKernel<scalar_t, bool>>(
203  device, indexer);
204  break;
205  case BinaryEWOpCode::Le:
206  ParallelForSYCL<LeqElementKernel<scalar_t, bool>>(
207  device, indexer);
208  break;
209  case BinaryEWOpCode::Eq:
210  ParallelForSYCL<EqElementKernel<scalar_t, bool>>(
211  device, indexer);
212  break;
213  case BinaryEWOpCode::Ne:
214  ParallelForSYCL<NeqElementKernel<scalar_t, bool>>(
215  device, indexer);
216  break;
217  default:
218  break;
219  }
220  });
221  } else {
223  "Boolean op's output type must be boolean or the "
224  "same type as the input.");
225  }
// Maximum / Minimum: inputs and output share one Indexer with
// DtypePolicy::ALL_SAME; the dispatch includes Bool.
226  } else if (op_code == BinaryEWOpCode::Maximum ||
227  op_code == BinaryEWOpCode::Minimum) {
228  Indexer indexer({lhs, rhs}, dst, DtypePolicy::ALL_SAME);
229  DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(src_dtype, [&]() {
230  switch (op_code) {
232  ParallelForSYCL<MaxElementKernel<scalar_t>>(device,
233  indexer);
234  break;
236  ParallelForSYCL<MinElementKernel<scalar_t>>(device,
237  indexer);
238  break;
239  default:
240  break;
241  }
242  });
// Remaining op codes: arithmetic (Add / Sub / Mul / Div).
243  } else {
244  Indexer indexer({lhs, rhs}, dst, DtypePolicy::ALL_SAME);
245  DISPATCH_DTYPE_TO_TEMPLATE(src_dtype, [&]() {
246  switch (op_code) {
247  case BinaryEWOpCode::Add:
248  ParallelForSYCL<AddElementKernel<scalar_t>>(device,
249  indexer);
250  break;
251  case BinaryEWOpCode::Sub:
252  ParallelForSYCL<SubElementKernel<scalar_t>>(device,
253  indexer);
254  break;
255  case BinaryEWOpCode::Mul:
256  ParallelForSYCL<MulElementKernel<scalar_t>>(device,
257  indexer);
258  break;
259  case BinaryEWOpCode::Div:
260  ParallelForSYCL<DivElementKernel<scalar_t>>(device,
261  indexer);
262  break;
263  default:
264  break;
265  }
266  });
267  }
268 }
269 } // namespace kernel
270 } // namespace core
271 } // namespace cloudViewer
Indexer indexer
#define BINARY_ELEMENT_KERNEL(name, elem_fn)
#define DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(DTYPE,...)
Definition: Dispatch.h:68
#define DISPATCH_DTYPE_TO_TEMPLATE(DTYPE,...)
Definition: Dispatch.h:31
Dtype GetDtype() const
Definition: Tensor.h:1164
Device GetDevice() const override
Definition: Tensor.cpp:1435
#define LogError(...)
Definition: Logging.h:60
int min(int a, int b)
Definition: cutil_math.h:53
int max(int a, int b)
Definition: cutil_math.h:48
const std::unordered_set< BinaryEWOpCode, utility::hash_enum_class > s_boolean_binary_ew_op_codes
Definition: BinaryEW.cpp:22
void BinaryEWSYCL(const Tensor &lhs, const Tensor &rhs, Tensor &dst, BinaryEWOpCode op_code)
const Dtype Bool
Definition: Dtype.cpp:52
void ParallelForSYCL(const Device &device, Indexer indexer, FuncArgs... func_args)
Run a function in parallel with SYCL.
Generic file read and write utility for python interface.