25 struct BinaryElementKernel {
26 void operator()(int64_t i) {}
27 BinaryElementKernel(Indexer indexer_) :
indexer(indexer_) {}
34 #define BINARY_ELEMENT_KERNEL(name, elem_fn) \
35 template <typename src_t, typename dst_t = src_t> \
36 struct name##ElementKernel : public BinaryElementKernel { \
37 using BinaryElementKernel::BinaryElementKernel; \
38 void operator()(int64_t i) { \
39 const src_t* lhs = indexer.GetInputPtr<src_t>(0, i); \
40 const src_t* rhs = indexer.GetInputPtr<src_t>(1, i); \
41 dst_t* dst = indexer.GetOutputPtr<dst_t>(i); \
42 *dst = elem_fn(*lhs, *rhs); \
48 #undef BINARY_ELEMENT_KERNEL
52 struct MaxElementKernel<bool, bool> :
public BinaryElementKernel {
53 using BinaryElementKernel::BinaryElementKernel;
54 void operator()(int64_t i) {
55 const bool* lhs =
indexer.GetInputPtr<
bool>(0, i);
56 const bool* rhs =
indexer.GetInputPtr<
bool>(1, i);
57 bool* dst =
indexer.GetOutputPtr<
bool>(i);
62 struct MinElementKernel<bool, bool> :
public BinaryElementKernel {
63 using BinaryElementKernel::BinaryElementKernel;
64 void operator()(int64_t i) {
65 const bool* lhs =
indexer.GetInputPtr<
bool>(0, i);
66 const bool* rhs =
indexer.GetInputPtr<
bool>(1, i);
67 bool* dst =
indexer.GetOutputPtr<
bool>(i);
73 #define BINARY_ELEMENT_KERNEL(name, elem_op) \
74 template <typename src_t, typename dst_t = src_t> \
75 struct name##ElementKernel : public BinaryElementKernel { \
76 using BinaryElementKernel::BinaryElementKernel; \
77 void operator()(int64_t i) { \
78 const src_t* lhs = indexer.GetInputPtr<src_t>(0, i); \
79 const src_t* rhs = indexer.GetInputPtr<src_t>(1, i); \
80 dst_t* dst = indexer.GetOutputPtr<dst_t>(i); \
81 *dst = (*lhs)elem_op(*rhs); \
95 #undef BINARY_ELEMENT_KERNEL
98 #define BINARY_ELEMENT_KERNEL(name, elem_op) \
99 template <typename src_t, typename dst_t = src_t> \
100 struct name##ElementKernel : public BinaryElementKernel { \
101 using BinaryElementKernel::BinaryElementKernel; \
102 void operator()(int64_t i) { \
103 const src_t* lhs = indexer.GetInputPtr<src_t>(0, i); \
104 const src_t* rhs = indexer.GetInputPtr<src_t>(1, i); \
105 dst_t* dst = indexer.GetOutputPtr<dst_t>(i); \
106 *dst = static_cast<bool>(*lhs) elem_op static_cast<bool>(*rhs); \
112 #undef BINARY_ELEMENT_KERNEL
126 if (dst_dtype == src_dtype) {
134 ParallelForSYCL<LogicalAndElementKernel<scalar_t>>(
138 ParallelForSYCL<LogicalOrElementKernel<scalar_t>>(
142 ParallelForSYCL<LogicalXorElementKernel<scalar_t>>(
146 ParallelForSYCL<GtElementKernel<scalar_t>>(device,
150 ParallelForSYCL<LtElementKernel<scalar_t>>(device,
154 ParallelForSYCL<GeqElementKernel<scalar_t>>(device,
158 ParallelForSYCL<LeqElementKernel<scalar_t>>(device,
162 ParallelForSYCL<EqElementKernel<scalar_t>>(device,
166 ParallelForSYCL<NeqElementKernel<scalar_t>>(device,
181 LogicalAndElementKernel<scalar_t, bool>>(
185 ParallelForSYCL<LogicalOrElementKernel<scalar_t, bool>>(
190 LogicalXorElementKernel<scalar_t, bool>>(
194 ParallelForSYCL<GtElementKernel<scalar_t, bool>>(
198 ParallelForSYCL<LtElementKernel<scalar_t, bool>>(
202 ParallelForSYCL<GeqElementKernel<scalar_t, bool>>(
206 ParallelForSYCL<LeqElementKernel<scalar_t, bool>>(
210 ParallelForSYCL<EqElementKernel<scalar_t, bool>>(
214 ParallelForSYCL<NeqElementKernel<scalar_t, bool>>(
223 "Boolean op's output type must be boolean or the "
224 "same type as the input.");
232 ParallelForSYCL<MaxElementKernel<scalar_t>>(device,
236 ParallelForSYCL<MinElementKernel<scalar_t>>(device,
248 ParallelForSYCL<AddElementKernel<scalar_t>>(device,
252 ParallelForSYCL<SubElementKernel<scalar_t>>(device,
256 ParallelForSYCL<MulElementKernel<scalar_t>>(device,
260 ParallelForSYCL<DivElementKernel<scalar_t>>(device,
#define BINARY_ELEMENT_KERNEL(name, elem_fn)
#define DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(DTYPE,...)
#define DISPATCH_DTYPE_TO_TEMPLATE(DTYPE,...)
Device GetDevice() const override
const std::unordered_set< BinaryEWOpCode, utility::hash_enum_class > s_boolean_binary_ew_op_codes
void BinaryEWSYCL(const Tensor &lhs, const Tensor &rhs, Tensor &dst, BinaryEWOpCode op_code)
void ParallelForSYCL(const Device &device, Indexer indexer, FuncArgs... func_args)
Run a function in parallel with SYCL.
Generic file read and write utility for python interface.