1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
8 #include <thrust/device_vector.h>
9 #include <thrust/execution_policy.h>
10 #include <thrust/for_each.h>
11 #include <thrust/iterator/zip_iterator.h>
13 #include "core/Indexer.h"
14 #include "core/kernel/NonZero.h"
16 namespace cloudViewer {
// Predicate used by thrust::copy_if to select elements that are non-zero.
//
// The comparison is performed in T itself. Casting to float first would
// underflow small-magnitude doubles (e.g. 1e-50 -> 0.0f), silently treating
// genuinely non-zero elements as zero, and comparing against the double
// literal 0.0 would force a double-precision compare on the device.
// NaN is reported as non-zero, since NaN != 0 is true.
template <typename T>
struct NonZeroFunctor {
    /// \param value  Element value from the source tensor.
    /// \return true iff `value` compares unequal to zero of type T.
    __host__ __device__ bool operator()(T value) const {
        return value != static_cast<T>(0);
    }
};
// Decodes flattened element offsets into per-dimension coordinates.
//
// For the k-th selected element with flat offset f, the coordinate along
// dimension d is written to workload index `d * num_non_zeros + k` of the
// wrapped TensorIterator. The decoding peels the last dimension first (it
// varies fastest), so f is interpreted as a row-major offset.
//
// NOTE(review): `shape_` holds at most MAX_DIMS extents and the constructor
// copies `shape.size()` entries without checking that bound -- callers must
// guarantee shape.size() <= MAX_DIMS.
struct FlatIndexTransformFunctor {
    /// \param iter           Iterator over the output tensor.
    /// \param num_non_zeros  Number of selected elements (stride between
    ///                       consecutive dimensions in the output).
    /// \param num_dims       Number of dimensions to decode.
    /// \param shape          Extents used to decode the flat offsets.
    FlatIndexTransformFunctor(const TensorIterator& iter,
                              int64_t num_non_zeros,
                              int64_t num_dims,
                              const SizeVector& shape)
        : iter_(iter), num_non_zeros_(num_non_zeros), num_dims_(num_dims) {
        const size_t rank = shape.size();
        for (size_t d = 0; d != rank; ++d) {
            shape_[d] = shape[d];
        }
    }

    /// \param t  Tuple (k, f): k is the element's position among the
    ///           selected elements, f is its flattened source offset.
    template <typename Tuple>
    __host__ __device__ void operator()(Tuple t) {
        const int64_t column = static_cast<int64_t>(thrust::get<0>(t));
        int64_t remaining = static_cast<int64_t>(thrust::get<1>(t));

        // Walk dimensions from last to first: the remainder by the current
        // extent is that dimension's coordinate, the quotient carries on.
        int64_t dim = num_dims_;
        while (dim-- > 0) {
            const int64_t extent = shape_[dim];
            const int64_t quotient = remaining / extent;
            int64_t* out = static_cast<int64_t*>(
                    iter_.GetPtr(dim * num_non_zeros_ + column));
            *out = remaining - quotient * extent;
            remaining = quotient;
        }
    }

protected:
    TensorIterator iter_;
    int64_t num_non_zeros_;
    int64_t num_dims_;
    int64_t shape_[MAX_DIMS];
};
// CUDA implementation of NonZero: returns the coordinates of all non-zero
// elements of `src`.
//
// The result is an Int64 tensor of shape {num_dims, num_non_zeros} on
// src's device. Column k holds the per-dimension coordinates of the k-th
// non-zero element, in increasing flattened-offset order.
//
// NOTE(review): thrust::device_vector allocates on the *current* CUDA
// device; confirm callers make src's device current on multi-GPU systems.
Tensor NonZeroCUDA(const Tensor& src) {
    Tensor src_contiguous = src.Contiguous();
    const int64_t num_elements = src_contiguous.NumElements();

    thrust::counting_iterator<int64_t> index_first(0);

    // Step 1: gather the flattened offsets of all non-zero elements.
    thrust::device_vector<int64_t> non_zero_indices(num_elements);
    DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(src.GetDtype(), [&]() {
        // Read through the tensor's own data pointer. The blob pointer is
        // the start of the whole allocation. A view whose strides are
        // already contiguous (e.g. a slice along the first dimension) can
        // begin at an offset inside that allocation. Reading from the blob
        // base would then scan the wrong elements.
        // NOTE(review): assumes Contiguous() returns such views without
        // copying -- confirm against Tensor::Contiguous.
        thrust::device_ptr<const scalar_t> src_ptr(
                static_cast<const scalar_t*>(src_contiguous.GetDataPtr()));

        auto it = thrust::copy_if(index_first, index_first + num_elements,
                                  src_ptr, non_zero_indices.begin(),
                                  NonZeroFunctor<scalar_t>());
        non_zero_indices.resize(thrust::distance(non_zero_indices.begin(), it));
    });

    // Step 2: transform flattened offsets into per-dimension coordinates.
    SizeVector shape = src.GetShape();
    const int64_t num_dims = src.NumDims();
    const int64_t num_non_zeros =
            static_cast<int64_t>(non_zero_indices.size());

    SizeVector result_shape{num_dims, num_non_zeros};
    Tensor result(result_shape, core::Int64, src.GetDevice());
    TensorIterator result_iter(result);

    thrust::for_each(thrust::device,
                     thrust::make_zip_iterator(thrust::make_tuple(
                             index_first, non_zero_indices.begin())),
                     thrust::make_zip_iterator(thrust::make_tuple(
                             index_first + num_non_zeros,
                             non_zero_indices.end())),
                     FlatIndexTransformFunctor(result_iter, num_non_zeros,
                                               num_dims, shape));

    return result;
}
100 } // namespace kernel
102 } // namespace cloudViewer