1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
8 #include "cloudViewer/core/CUDAUtils.h"
9 #include "core/Dispatch.h"
10 #include "core/Indexer.h"
11 #include "core/ParallelFor.h"
12 #include "core/Tensor.h"
13 #include "core/kernel/BinaryEW.h"
15 namespace cloudViewer {
19 // Cannot be a static function since on Windows a function enclosing
20 // __host__ __device__ lambda function must have external linkage.
21 template <typename src_t, typename dst_t, typename func_t>
22 void LaunchBinaryEWKernel(const Device& device,
23 const Indexer& indexer,
24 const func_t& element_kernel) {
25 CLOUDVIEWER_ASSERT_HOST_DEVICE_LAMBDA(func_t);
26 auto element_func = [=] CLOUDVIEWER_HOST_DEVICE(int64_t i) {
27 element_kernel(indexer.GetInputPtr<src_t>(0, i),
28 indexer.GetInputPtr<src_t>(1, i),
29 indexer.GetOutputPtr<dst_t>(i));
31 ParallelFor(device, indexer.NumWorkloads(), element_func);
32 CLOUDVIEWER_GET_LAST_CUDA_ERROR("LaunchBinaryEWKernel failed.");
35 template <typename scalar_t>
36 static CLOUDVIEWER_HOST_DEVICE void CUDAMaxElementKernel(const void* lhs,
39 *static_cast<scalar_t*>(dst) = max(*static_cast<const scalar_t*>(lhs),
40 *static_cast<const scalar_t*>(rhs));
43 template <typename scalar_t>
44 static CLOUDVIEWER_HOST_DEVICE void CUDAMinElementKernel(const void* lhs,
47 *static_cast<scalar_t*>(dst) = min(*static_cast<const scalar_t*>(lhs),
48 *static_cast<const scalar_t*>(rhs));
51 template <typename scalar_t>
52 static CLOUDVIEWER_HOST_DEVICE void CUDAAddElementKernel(const void* lhs,
55 *static_cast<scalar_t*>(dst) = *static_cast<const scalar_t*>(lhs) +
56 *static_cast<const scalar_t*>(rhs);
59 template <typename scalar_t>
60 static CLOUDVIEWER_HOST_DEVICE void CUDASubElementKernel(const void* lhs,
63 *static_cast<scalar_t*>(dst) = *static_cast<const scalar_t*>(lhs) -
64 *static_cast<const scalar_t*>(rhs);
67 template <typename scalar_t>
68 static CLOUDVIEWER_HOST_DEVICE void CUDAMulElementKernel(const void* lhs,
71 *static_cast<scalar_t*>(dst) = *static_cast<const scalar_t*>(lhs) *
72 *static_cast<const scalar_t*>(rhs);
75 template <typename scalar_t>
76 static CLOUDVIEWER_HOST_DEVICE void CUDADivElementKernel(const void* lhs,
79 *static_cast<scalar_t*>(dst) = *static_cast<const scalar_t*>(lhs) /
80 *static_cast<const scalar_t*>(rhs);
83 template <typename src_t, typename dst_t>
84 static CLOUDVIEWER_HOST_DEVICE void CUDALogicalAndElementKernel(const void* lhs,
87 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
88 static_cast<bool>(*static_cast<const src_t*>(lhs)) &&
89 static_cast<bool>(*static_cast<const src_t*>(rhs)));
92 template <typename src_t, typename dst_t>
93 static CLOUDVIEWER_HOST_DEVICE void CUDALogicalOrElementKernel(const void* lhs,
96 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
97 static_cast<bool>(*static_cast<const src_t*>(lhs)) ||
98 static_cast<bool>(*static_cast<const src_t*>(rhs)));
101 template <typename src_t, typename dst_t>
102 static CLOUDVIEWER_HOST_DEVICE void CUDALogicalXorElementKernel(const void* lhs,
105 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
106 static_cast<bool>(*static_cast<const src_t*>(lhs)) !=
107 static_cast<bool>(*static_cast<const src_t*>(rhs)));
110 template <typename src_t, typename dst_t>
111 static CLOUDVIEWER_HOST_DEVICE void CUDAGtElementKernel(const void* lhs,
114 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
115 *static_cast<const src_t*>(lhs) > *static_cast<const src_t*>(rhs));
118 template <typename src_t, typename dst_t>
119 static CLOUDVIEWER_HOST_DEVICE void CUDALtElementKernel(const void* lhs,
122 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
123 *static_cast<const src_t*>(lhs) < *static_cast<const src_t*>(rhs));
126 template <typename src_t, typename dst_t>
127 static void CLOUDVIEWER_HOST_DEVICE CUDAGeqElementKernel(const void* lhs,
130 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
131 *static_cast<const src_t*>(lhs) >= *static_cast<const src_t*>(rhs));
134 template <typename src_t, typename dst_t>
135 static void CLOUDVIEWER_HOST_DEVICE CUDALeqElementKernel(const void* lhs,
138 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
139 *static_cast<const src_t*>(lhs) <= *static_cast<const src_t*>(rhs));
142 template <typename src_t, typename dst_t>
143 static void CLOUDVIEWER_HOST_DEVICE CUDAEqElementKernel(const void* lhs,
146 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
147 *static_cast<const src_t*>(lhs) == *static_cast<const src_t*>(rhs));
150 template <typename src_t, typename dst_t>
151 static void CLOUDVIEWER_HOST_DEVICE CUDANeqElementKernel(const void* lhs,
154 *static_cast<dst_t*>(dst) = static_cast<dst_t>(
155 *static_cast<const src_t*>(lhs) != *static_cast<const src_t*>(rhs));
158 template <typename src_t, typename dst_t>
159 void LaunchBoolBinaryEWCUDAKernel(const Tensor& lhs,
162 BinaryEWOpCode op_code,
163 const Indexer& indexer) {
164 Device device = lhs.GetDevice();
166 case BinaryEWOpCode::LogicalAnd:
167 LaunchBinaryEWKernel<src_t, dst_t>(
169 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
171 CUDALogicalAndElementKernel<src_t, dst_t>(lhs, rhs,
175 case BinaryEWOpCode::LogicalOr:
176 LaunchBinaryEWKernel<src_t, dst_t>(
178 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
180 CUDALogicalOrElementKernel<src_t, dst_t>(lhs, rhs, dst);
183 case BinaryEWOpCode::LogicalXor:
184 LaunchBinaryEWKernel<src_t, dst_t>(
186 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
188 CUDALogicalXorElementKernel<src_t, dst_t>(lhs, rhs,
192 case BinaryEWOpCode::Gt:
193 LaunchBinaryEWKernel<src_t, dst_t>(
195 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
197 CUDAGtElementKernel<src_t, dst_t>(lhs, rhs, dst);
200 case BinaryEWOpCode::Lt:
201 LaunchBinaryEWKernel<src_t, dst_t>(
203 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
205 CUDALtElementKernel<src_t, dst_t>(lhs, rhs, dst);
208 case BinaryEWOpCode::Ge:
209 LaunchBinaryEWKernel<src_t, dst_t>(
211 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
213 CUDAGeqElementKernel<src_t, dst_t>(lhs, rhs, dst);
216 case BinaryEWOpCode::Le:
217 LaunchBinaryEWKernel<src_t, dst_t>(
219 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
221 CUDALeqElementKernel<src_t, dst_t>(lhs, rhs, dst);
224 case BinaryEWOpCode::Eq:
225 LaunchBinaryEWKernel<src_t, dst_t>(
227 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
229 CUDAEqElementKernel<src_t, dst_t>(lhs, rhs, dst);
232 case BinaryEWOpCode::Ne:
233 LaunchBinaryEWKernel<src_t, dst_t>(
235 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs, void* rhs,
237 CUDANeqElementKernel<src_t, dst_t>(lhs, rhs, dst);
245 void BinaryEWCUDA(const Tensor& lhs,
248 BinaryEWOpCode op_code) {
249 // It has been checked that
250 // - lhs, rhs, dst are all in the same CUDA device
251 // - lhs, rhs have the same dtype, dst also has the same dtype or is boolean
252 Device src_device = lhs.GetDevice();
253 Dtype src_dtype = lhs.GetDtype();
254 Dtype dst_dtype = dst.GetDtype();
256 CUDAScopedDevice scoped_device(src_device);
258 if (s_boolean_binary_ew_op_codes.find(op_code) !=
259 s_boolean_binary_ew_op_codes.end()) {
260 DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(src_dtype, [&]() {
261 if (dst_dtype == src_dtype) {
262 // Inplace boolean op's output type is the same as the
263 // input. e.g. np.logical_and(a, b, out=a), where a, b are
265 Indexer indexer({lhs, rhs}, dst, DtypePolicy::ALL_SAME);
266 LaunchBoolBinaryEWCUDAKernel<scalar_t, scalar_t>(
267 lhs, rhs, dst, op_code, indexer);
268 } else if (dst_dtype == core::Bool) {
269 // By default, output is boolean type.
270 Indexer indexer({lhs, rhs}, dst,
271 DtypePolicy::INPUT_SAME_OUTPUT_BOOL);
273 LaunchBoolBinaryEWCUDAKernel<scalar_t, bool>(lhs, rhs, dst,
277 "Boolean op's output type must be boolean or the "
278 "same type as the input.");
281 } else if (op_code == BinaryEWOpCode::Maximum ||
282 op_code == BinaryEWOpCode::Minimum) {
283 Indexer indexer({lhs, rhs}, dst, DtypePolicy::ALL_SAME);
284 DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(src_dtype, [&]() {
286 case BinaryEWOpCode::Maximum:
287 LaunchBinaryEWKernel<scalar_t, scalar_t>(
289 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs,
290 void* rhs, void* dst) {
291 CUDAMaxElementKernel<scalar_t>(lhs, rhs, dst);
294 case BinaryEWOpCode::Minimum:
295 LaunchBinaryEWKernel<scalar_t, scalar_t>(
297 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs,
298 void* rhs, void* dst) {
299 CUDAMinElementKernel<scalar_t>(lhs, rhs, dst);
307 Indexer indexer({lhs, rhs}, dst, DtypePolicy::ALL_SAME);
308 DISPATCH_DTYPE_TO_TEMPLATE(src_dtype, [&]() {
310 case BinaryEWOpCode::Add:
311 LaunchBinaryEWKernel<scalar_t, scalar_t>(
313 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs,
314 void* rhs, void* dst) {
315 CUDAAddElementKernel<scalar_t>(lhs, rhs, dst);
318 case BinaryEWOpCode::Sub:
319 LaunchBinaryEWKernel<scalar_t, scalar_t>(
321 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs,
322 void* rhs, void* dst) {
323 CUDASubElementKernel<scalar_t>(lhs, rhs, dst);
326 case BinaryEWOpCode::Mul:
327 LaunchBinaryEWKernel<scalar_t, scalar_t>(
329 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs,
330 void* rhs, void* dst) {
331 CUDAMulElementKernel<scalar_t>(lhs, rhs, dst);
334 case BinaryEWOpCode::Div:
335 LaunchBinaryEWKernel<scalar_t, scalar_t>(
337 [] CLOUDVIEWER_HOST_DEVICE(const void* lhs,
338 void* rhs, void* dst) {
339 CUDADivElementKernel<scalar_t>(lhs, rhs, dst);
349 } // namespace kernel
351 } // namespace cloudViewer