ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
NonZeroCUDA.cu
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include <thrust/device_vector.h>
9 #include <thrust/execution_policy.h>
10 #include <thrust/for_each.h>
11 #include <thrust/iterator/zip_iterator.h>
12 
13 #include "core/Indexer.h"
14 #include "core/kernel/NonZero.h"
15 
16 namespace cloudViewer {
17 namespace core {
18 namespace kernel {
19 
// Predicate functor: returns true iff `value` is non-zero.
//
// The comparison is performed in the element's own type T. Casting to
// float first (as the previous implementation did) misclassifies values
// whose magnitude underflows float: e.g. a double of 1e-300 casts to
// 0.0f and would be reported as zero even though it is not.
template <typename T>
struct NonZeroFunctor {
    NonZeroFunctor() {}
    __host__ __device__ bool operator()(T value) const {
        return value != static_cast<T>(0);
    }
};
27 
// Converts a flattened (row-major) element index into per-dimension
// coordinates and stores them through the result TensorIterator.
//
// Output layout: the coordinate along dimension d of the k-th non-zero
// element is written to flat offset d * num_non_zeros + k, i.e. one row
// of the result per dimension.
struct FlatIndexTransformFunctor {
    FlatIndexTransformFunctor(const TensorIterator& iter,
                              int64_t num_non_zeros,
                              int64_t num_dims,
                              const SizeVector& shape)
        : iter_(iter), num_non_zeros_(num_non_zeros), num_dims_(num_dims) {
        // Copy the extents into a fixed-size member array so the functor
        // carries them by value.
        // NOTE(review): assumes shape.size() <= MAX_DIMS — presumably
        // enforced upstream; confirm.
        for (size_t d = 0; d < shape.size(); ++d) {
            shape_[d] = shape[d];
        }
    }

    // Tuple layout: get<0> = output slot index, get<1> = flat non-zero index.
    template <typename Tuple>
    __host__ __device__ void operator()(Tuple t) {
        const int64_t slot = static_cast<int64_t>(thrust::get<0>(t));
        int64_t remainder = static_cast<int64_t>(thrust::get<1>(t));

        // Peel coordinates off from the innermost dimension outwards.
        int64_t dim = num_dims_;
        while (dim-- > 0) {
            int64_t* out = static_cast<int64_t*>(
                    iter_.GetPtr(dim * num_non_zeros_ + slot));
            *out = remainder % shape_[dim];
            remainder /= shape_[dim];
        }
    }

protected:
    TensorIterator iter_;
    int64_t num_non_zeros_;
    int64_t num_dims_;
    int64_t shape_[MAX_DIMS];
};
57 
// Returns the indices of the non-zero elements of `src` as a
// (num_dims, num_non_zeros) Int64 tensor: row d holds each non-zero
// element's coordinate along dimension d.
Tensor NonZeroCUDA(const Tensor& src) {
    // Work on a contiguous copy so flat indices map to row-major order.
    Tensor src_contiguous = src.Contiguous();
    const int64_t num_elements = src_contiguous.NumElements();

    thrust::counting_iterator<int64_t> index_first(0);
    thrust::counting_iterator<int64_t> index_last = index_first + num_elements;

    // Step 1: collect the flattened indices of the non-zero elements.
    // copy_if's stencil overload keeps index i iff src_ptr[i] is non-zero.
    thrust::device_vector<int64_t> non_zero_indices(num_elements);
    DISPATCH_DTYPE_TO_TEMPLATE_WITH_BOOL(src.GetDtype(), [&]() {
        thrust::device_ptr<const scalar_t> src_ptr(static_cast<const scalar_t*>(
                src_contiguous.GetBlob()->GetDataPtr()));

        auto it = thrust::copy_if(index_first, index_last, src_ptr,
                                  non_zero_indices.begin(),
                                  NonZeroFunctor<scalar_t>());
        // Shrink to the number of elements actually kept.
        non_zero_indices.resize(thrust::distance(non_zero_indices.begin(), it));
    });

    // Step 2: transform flattened indices into indices in each dimension.
    SizeVector shape = src.GetShape();
    const int64_t num_dims = src.NumDims();
    // Normalize the count to int64_t once, so it mixes cleanly with the
    // int64_t counting iterators and the functor's int64_t parameter.
    const int64_t num_non_zeros = static_cast<int64_t>(non_zero_indices.size());

    SizeVector result_shape{num_dims, num_non_zeros};
    Tensor result(result_shape, core::Int64, src.GetDevice());
    TensorIterator result_iter(result);

    index_last = index_first + num_non_zeros;
    thrust::for_each(thrust::device,
                     thrust::make_zip_iterator(thrust::make_tuple(
                             index_first, non_zero_indices.begin())),
                     thrust::make_zip_iterator(thrust::make_tuple(
                             index_last, non_zero_indices.end())),
                     FlatIndexTransformFunctor(result_iter, num_non_zeros,
                                               num_dims, shape));

    return result;
}
99 
100 } // namespace kernel
101 } // namespace core
102 } // namespace cloudViewer