// ----------------------------------------------------------------------------
// - CloudViewer: www.cloudViewer.org -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------
#include "cloudViewer/core/CUDAUtils.h"
#include "core/AdvancedIndexing.h"
#include "core/Dispatch.h"
#include "core/Indexer.h"
#include "core/ParallelFor.h"
#include "core/Tensor.h"
#include "core/kernel/IndexGetSet.h"
16 namespace cloudViewer {
20 template <typename func_t>
21 void LaunchAdvancedIndexerKernel(const Device& device,
22 const AdvancedIndexer& indexer,
23 const func_t& element_kernel) {
24 CLOUDVIEWER_ASSERT_HOST_DEVICE_LAMBDA(func_t);
25 auto element_func = [=] CLOUDVIEWER_HOST_DEVICE(int64_t i) {
26 element_kernel(indexer.GetInputPtr(i), indexer.GetOutputPtr(i));
28 ParallelFor(device, indexer.NumWorkloads(), element_func);
29 CLOUDVIEWER_GET_LAST_CUDA_ERROR("LaunchAdvancedIndexerKernel failed.");
32 template <typename scalar_t>
33 static CLOUDVIEWER_HOST_DEVICE void CUDACopyElementKernel(const void* src,
35 *static_cast<scalar_t*>(dst) = *static_cast<const scalar_t*>(src);
38 static CLOUDVIEWER_HOST_DEVICE void CUDACopyObjectElementKernel(
39 const void* src, void* dst, int64_t object_byte_size) {
40 const char* src_bytes = static_cast<const char*>(src);
41 char* dst_bytes = static_cast<char*>(dst);
42 for (int i = 0; i < object_byte_size; ++i) {
43 dst_bytes[i] = src_bytes[i];
47 void IndexGetCUDA(const Tensor& src,
49 const std::vector<Tensor>& index_tensors,
50 const SizeVector& indexed_shape,
51 const SizeVector& indexed_strides) {
52 Dtype dtype = src.GetDtype();
53 AdvancedIndexer ai(src, dst, index_tensors, indexed_shape, indexed_strides,
54 AdvancedIndexer::AdvancedIndexerMode::GET);
56 if (dtype.IsObject()) {
57 int64_t object_byte_size = dtype.ByteSize();
58 LaunchAdvancedIndexerKernel(
60 [=] CLOUDVIEWER_HOST_DEVICE(const void* src, void* dst) {
61 CUDACopyObjectElementKernel(src, dst, object_byte_size);
64 DISPATCH_DTYPE_TO_TEMPLATE(dtype, [&]() {
65 LaunchAdvancedIndexerKernel(
67 // Need to wrap as extended CUDA lambda function
68 [] CLOUDVIEWER_HOST_DEVICE(const void* src, void* dst) {
69 CUDACopyElementKernel<scalar_t>(src, dst);
75 void IndexSetCUDA(const Tensor& src,
77 const std::vector<Tensor>& index_tensors,
78 const SizeVector& indexed_shape,
79 const SizeVector& indexed_strides) {
80 Dtype dtype = src.GetDtype();
81 AdvancedIndexer ai(src, dst, index_tensors, indexed_shape, indexed_strides,
82 AdvancedIndexer::AdvancedIndexerMode::SET);
84 if (dtype.IsObject()) {
85 int64_t object_byte_size = dtype.ByteSize();
86 LaunchAdvancedIndexerKernel(
88 [=] CLOUDVIEWER_HOST_DEVICE(const void* src, void* dst) {
89 CUDACopyObjectElementKernel(src, dst, object_byte_size);
92 DISPATCH_DTYPE_TO_TEMPLATE(dtype, [&]() {
93 LaunchAdvancedIndexerKernel(
95 // Need to wrap as extended CUDA lambda function
96 [] CLOUDVIEWER_HOST_DEVICE(const void* src, void* dst) {
97 CUDACopyElementKernel<scalar_t>(src, dst);
103 } // namespace kernel
105 } // namespace cloudViewer