19 #ifdef BUILD_ISPC_MODULE
20 #include "BinaryEWCPU_ispc.h"
27 template <
typename src_t,
typename dst_t,
typename element_func_t>
29 const element_func_t& element_func) {
31 [&
indexer, &element_func](int64_t i) {
32 element_func(indexer.GetInputPtr<src_t>(0, i),
33 indexer.GetInputPtr<src_t>(1, i),
34 indexer.GetOutputPtr<dst_t>(i));
38 template <
typename src_t,
40 typename element_func_t,
43 const element_func_t& element_func,
44 const vec_func_t& vec_func) {
47 [&
indexer, &element_func](int64_t i) {
48 element_func(indexer.GetInputPtr<src_t>(0, i),
49 indexer.GetInputPtr<src_t>(1, i),
50 indexer.GetOutputPtr<dst_t>(i));
55 template <
typename scalar_t>
57 *
static_cast<scalar_t*
>(dst) =
std::max(*
static_cast<const scalar_t*
>(lhs),
58 *
static_cast<const scalar_t*
>(rhs));
61 template <
typename scalar_t>
63 *
static_cast<scalar_t*
>(dst) =
std::min(*
static_cast<const scalar_t*
>(lhs),
64 *
static_cast<const scalar_t*
>(rhs));
67 template <
typename scalar_t>
69 *
static_cast<scalar_t*
>(dst) = *
static_cast<const scalar_t*
>(lhs) +
70 *
static_cast<const scalar_t*
>(rhs);
73 template <
typename scalar_t>
75 *
static_cast<scalar_t*
>(dst) = *
static_cast<const scalar_t*
>(lhs) -
76 *
static_cast<const scalar_t*
>(rhs);
79 template <
typename scalar_t>
81 *
static_cast<scalar_t*
>(dst) = *
static_cast<const scalar_t*
>(lhs) *
82 *
static_cast<const scalar_t*
>(rhs);
85 template <
typename scalar_t>
87 *
static_cast<scalar_t*
>(dst) = *
static_cast<const scalar_t*
>(lhs) /
88 *
static_cast<const scalar_t*
>(rhs);
91 template <
typename src_t,
typename dst_t>
95 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
96 static_cast<bool>(*
static_cast<const src_t*
>(lhs)) &&
97 static_cast<bool>(*
static_cast<const src_t*
>(rhs)));
100 template <
typename src_t,
typename dst_t>
104 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
105 static_cast<bool>(*
static_cast<const src_t*
>(lhs)) ||
106 static_cast<bool>(*
static_cast<const src_t*
>(rhs)));
109 template <
typename src_t,
typename dst_t>
113 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
114 static_cast<bool>(*
static_cast<const src_t*
>(lhs)) !=
115 static_cast<bool>(*
static_cast<const src_t*
>(rhs)));
118 template <
typename src_t,
typename dst_t>
120 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
121 *
static_cast<const src_t*
>(lhs) > *
static_cast<const src_t*
>(rhs));
124 template <
typename src_t,
typename dst_t>
126 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
127 *
static_cast<const src_t*
>(lhs) < *
static_cast<const src_t*
>(rhs));
130 template <
typename src_t,
typename dst_t>
132 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
133 *
static_cast<const src_t*
>(lhs) >= *
static_cast<const src_t*
>(rhs));
136 template <
typename src_t,
typename dst_t>
138 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
139 *
static_cast<const src_t*
>(lhs) <= *
static_cast<const src_t*
>(rhs));
142 template <
typename src_t,
typename dst_t>
144 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
145 *
static_cast<const src_t*
>(lhs) == *
static_cast<const src_t*
>(rhs));
148 template <
typename src_t,
typename dst_t>
150 *
static_cast<dst_t*
>(dst) =
static_cast<dst_t
>(
151 *
static_cast<const src_t*
>(lhs) != *
static_cast<const src_t*
>(rhs));
163 if (dst_dtype == src_dtype) {
168 #ifdef BUILD_ISPC_MODULE
169 ispc::Indexer ispc_indexer =
indexer.ToISPC();
174 LaunchBinaryEWKernel<scalar_t, scalar_t>(
176 CPULogicalAndElementKernel<scalar_t, scalar_t>,
182 LaunchBinaryEWKernel<scalar_t, scalar_t>(
184 CPULogicalOrElementKernel<scalar_t, scalar_t>,
190 LaunchBinaryEWKernel<scalar_t, scalar_t>(
192 CPULogicalXorElementKernel<scalar_t, scalar_t>,
198 LaunchBinaryEWKernel<scalar_t, scalar_t>(
199 indexer, CPUGtElementKernel<scalar_t, scalar_t>,
201 scalar_t, CPULogicalGtElementKernel,
205 LaunchBinaryEWKernel<scalar_t, scalar_t>(
206 indexer, CPULtElementKernel<scalar_t, scalar_t>,
208 scalar_t, CPULogicalLtElementKernel,
212 LaunchBinaryEWKernel<scalar_t, scalar_t>(
214 CPUGeqElementKernel<scalar_t, scalar_t>,
216 scalar_t, CPULogicalGeqElementKernel,
220 LaunchBinaryEWKernel<scalar_t, scalar_t>(
222 CPULeqElementKernel<scalar_t, scalar_t>,
224 scalar_t, CPULogicalLeqElementKernel,
228 LaunchBinaryEWKernel<scalar_t, scalar_t>(
229 indexer, CPUEqElementKernel<scalar_t, scalar_t>,
231 scalar_t, CPULogicalEqElementKernel,
235 LaunchBinaryEWKernel<scalar_t, scalar_t>(
237 CPUNeqElementKernel<scalar_t, scalar_t>,
239 scalar_t, CPULogicalNeqElementKernel,
250 #ifdef BUILD_ISPC_MODULE
251 ispc::Indexer ispc_indexer =
indexer.ToISPC();
256 LaunchBinaryEWKernel<scalar_t, bool>(
258 CPULogicalAndElementKernel<scalar_t, bool>,
261 CPULogicalAndElementKernel_bool,
265 LaunchBinaryEWKernel<scalar_t, bool>(
267 CPULogicalOrElementKernel<scalar_t, bool>,
270 CPULogicalOrElementKernel_bool,
274 LaunchBinaryEWKernel<scalar_t, bool>(
276 CPULogicalXorElementKernel<scalar_t, bool>,
279 CPULogicalXorElementKernel_bool,
283 LaunchBinaryEWKernel<scalar_t, bool>(
284 indexer, CPUGtElementKernel<scalar_t, bool>,
287 CPULogicalGtElementKernel_bool,
291 LaunchBinaryEWKernel<scalar_t, bool>(
292 indexer, CPULtElementKernel<scalar_t, bool>,
295 CPULogicalLtElementKernel_bool,
299 LaunchBinaryEWKernel<scalar_t, bool>(
300 indexer, CPUGeqElementKernel<scalar_t, bool>,
303 CPULogicalGeqElementKernel_bool,
307 LaunchBinaryEWKernel<scalar_t, bool>(
308 indexer, CPULeqElementKernel<scalar_t, bool>,
311 CPULogicalLeqElementKernel_bool,
315 LaunchBinaryEWKernel<scalar_t, bool>(
316 indexer, CPUEqElementKernel<scalar_t, bool>,
319 CPULogicalEqElementKernel_bool,
323 LaunchBinaryEWKernel<scalar_t, bool>(
324 indexer, CPUNeqElementKernel<scalar_t, bool>,
327 CPULogicalNeqElementKernel_bool,
336 "Boolean op's output type must be boolean or the "
337 "same type as the input.");
345 LaunchBinaryEWKernel<scalar_t, scalar_t>(
346 indexer, CPUMaxElementKernel<scalar_t>);
349 LaunchBinaryEWKernel<scalar_t, scalar_t>(
350 indexer, CPUMinElementKernel<scalar_t>);
358 #ifdef BUILD_ISPC_MODULE
359 ispc::Indexer ispc_indexer =
indexer.ToISPC();
364 LaunchBinaryEWKernel<scalar_t, scalar_t>(
365 indexer, CPUAddElementKernel<scalar_t>,
371 LaunchBinaryEWKernel<scalar_t, scalar_t>(
372 indexer, CPUSubElementKernel<scalar_t>,
378 LaunchBinaryEWKernel<scalar_t, scalar_t>(
379 indexer, CPUMulElementKernel<scalar_t>,
387 LaunchBinaryEWKernel<scalar_t, scalar_t>(
388 indexer, CPUDivElementKernel<scalar_t>);
#define DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(DTYPE,...)
#define DISPATCH_DTYPE_TO_TEMPLATE(DTYPE,...)
#define CLOUDVIEWER_TEMPLATE_VECTORIZED(T, ISPCKernel,...)
static void CPUGtElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPULogicalAndElementKernel(const void *lhs, const void *rhs, void *dst)
void BinaryEWCPU(const Tensor &lhs, const Tensor &rhs, Tensor &dst, BinaryEWOpCode op_code)
const std::unordered_set< BinaryEWOpCode, utility::hash_enum_class > s_boolean_binary_ew_op_codes
static void LaunchBinaryEWKernel(const Indexer &indexer, const element_func_t &element_func)
static void CPUMaxElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUMulElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPULogicalOrElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUEqElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUMinElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPULtElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUGeqElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPULeqElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUNeqElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPULogicalXorElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUDivElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUSubElementKernel(const void *lhs, const void *rhs, void *dst)
static void CPUAddElementKernel(const void *lhs, const void *rhs, void *dst)
void ParallelFor(const Device &device, int64_t n, const func_t &func)
Generic file read and write utility for python interface.