SparseConvTranspose.cuh
// ----------------------------------------------------------------------------
// -                     CloudViewer: www.cloudViewer.org                      -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once
#define EIGEN_USE_GPU

#include <Helper.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>

#include <algorithm>
#include <limits>
#include <sstream>
#include <utility>
#include <vector>

#include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
#include "ml/impl/misc/MemoryAllocation.h"
#include "ml/impl/sparse_conv/SparseConvCUDAKernels.h"

using cloudViewer::utility::DivUp;

namespace cloudViewer {
namespace ml {
namespace impl {

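// Computes the output point features of a transposed sparse convolution.
//
// The function uses a two-phase calling convention: when 'temp' is null it
// only writes the required temporary buffer sizes to 'temp_size' (minimum)
// and 'max_temp_size' and returns; the caller then allocates a buffer and
// calls the function again to do the actual work. The sketch below is
// illustrative only; the template arguments, the 'device_alloc' helper, and
// the argument names are placeholders, not part of this header:
//
//   size_t temp_size = 0, max_temp_size = 0;
//   SparseConvTransposeComputeFeaturesCUDA<float, float, int32_t, uint8_t>(
//           stream, nullptr, temp_size, max_temp_size, texture_alignment,
//           out_features, filter_dims, filter, num_out, out_importance,
//           num_inp, inp_features, inp_neighbors_importance_sum,
//           inp_neighbors_prefix_sum, neighbors_index_size, neighbors_index,
//           neighbors_kernel_index, neighbors_importance,
//           neighbors_row_splits, normalize);
//
//   void* temp = device_alloc(temp_size);  // placeholder for a CUDA allocation
//   SparseConvTransposeComputeFeaturesCUDA<float, float, int32_t, uint8_t>(
//           stream, temp, temp_size, max_temp_size, texture_alignment,
//           out_features, filter_dims, filter, num_out, out_importance,
//           num_inp, inp_features, inp_neighbors_importance_sum,
//           inp_neighbors_prefix_sum, neighbors_index_size, neighbors_index,
//           neighbors_kernel_index, neighbors_importance,
//           neighbors_row_splits, normalize);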
template <class TFeat, class TOut, class TIndex, class TKernelIndex>
void SparseConvTransposeComputeFeaturesCUDA(
        const cudaStream_t& stream,
        void* temp,
        size_t& temp_size,
        size_t& max_temp_size,
        int texture_alignment,
        TOut* out_features,
        const std::vector<int>& filter_dims,
        const TFeat* filter,
        TIndex num_out,
        const TFeat* out_importance,
        TIndex num_inp,
        const TFeat* inp_features,
        const TFeat* inp_neighbors_importance_sum,
        const int64_t* inp_neighbors_prefix_sum,
        size_t neighbors_index_size,
        const TIndex* neighbors_index,
        const TKernelIndex* neighbors_kernel_index,
        const TFeat* neighbors_importance,
        const int64_t* neighbors_row_splits,
        bool normalize) {
    const bool get_temp_size = !temp;

    if (get_temp_size) {
        // Use a maximally misaligned dummy pointer so that the size
        // computation below accounts for worst-case alignment padding.
        temp = (char*)1;
        temp_size = std::numeric_limits<int64_t>::max();
    }

    MemoryAllocation mem_temp(temp, temp_size, texture_alignment);

    const int in_channels = filter_dims[filter_dims.size() - 2];
    const int out_channels = filter_dims[filter_dims.size() - 1];

    int num_kernel_elements = 1;
    for (std::size_t i = 0; i < filter_dims.size() - 2; ++i)
        num_kernel_elements *= filter_dims[i];

    // This defines how much temporary storage we need at least;
    // we want to allocate memory for at least 32 output points.
    const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
    const size_t max_num_cols_per_run = num_out;
    const size_t bytes_per_column =
            sizeof(TFeat) * (num_kernel_elements * in_channels);
    const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
    const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
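    // Illustrative numbers (not taken from the code): for a 3x3x3 filter with
    // in_channels = 8 and TFeat = float, bytes_per_column is
    // 4 * (27 * 8) = 864 bytes, so the minimum temp size (32 output points)
    // is 32 * 864 = 27648 bytes and the maximum covers all num_out points.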

    if (get_temp_size) {
        std::pair<char*, size_t> tmp =
                mem_temp.Alloc<char>(min_temp_size_bytes);
        temp_size = mem_temp.MaxUsed();
        mem_temp.Free(tmp);
        mem_temp.Alloc<char>(max_temp_size_bytes);
        max_temp_size = mem_temp.MaxUsed();
        return;
    }

    // Request segment using all of the temporary memory
    std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();

    if (mem_columns.second < min_temp_size_bytes) {
        std::stringstream ss;
        ss << "temp is too small " << mem_columns.second
           << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
        throw std::runtime_error(ss.str());
    }

    // init output
    cudaMemsetAsync(out_features, 0, sizeof(TOut) * num_out * out_channels,
                    stream);

    size_t num_cols_per_run =
            std::min(mem_columns.second / bytes_per_column, size_t(num_out));

    typedef cutlass::gemm::SgemmTraits<
            cutlass::MatrixLayout::kColumnMajor,  // layout of A matrix
            cutlass::MatrixLayout::kColumnMajor,  // layout of B matrix
            cutlass::Shape<8, 64, 64>             // threadblock tile size
            >
            GemmTraits;

    typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
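
    // Each run below computes one chunk of the output with a single GEMM:
    //   C (out_channels x num_cols) = A (out_channels x K) * B (K x num_cols)
    // where K = num_kernel_elements * in_channels, A is the filter viewed as
    // a column-major out_channels x K matrix (lda = out_channels), B is the
    // gathered column matrix, and beta = 1 accumulates into the
    // zero-initialized out_features.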

    TFeat* columns = (TFeat*)mem_columns.first;

    // if we cannot process all data at once we need multiple runs
    size_t num_runs = DivUp(num_out, num_cols_per_run);
    for (size_t run_i = 0; run_i < num_runs; ++run_i) {
        const TIndex begin_idx = run_i * num_cols_per_run;
        const TIndex end_idx =
                std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
        const size_t num_cols_this_run = end_idx - begin_idx;

        FillColumnTranspose<TFeat, TIndex, TKernelIndex>(
                stream, columns, in_channels, begin_idx, end_idx, num_out,
                num_inp, inp_features, inp_neighbors_importance_sum,
                inp_neighbors_prefix_sum, neighbors_index_size, neighbors_index,
                neighbors_kernel_index, neighbors_importance,
                neighbors_row_splits, num_kernel_elements, normalize);
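
        // Each column of 'columns' now holds the gathered neighbor features
        // for one output point in [begin_idx, end_idx), with
        // num_kernel_elements * in_channels values per column (the K
        // dimension of the GEMM below).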

        typename Gemm::Params params;
        // C is MxN
        // B is KxN
        // A is MxK
        int m = out_channels;
        int k = num_kernel_elements * in_channels;
        int n = num_cols_this_run;
        float alpha = 1;
        const float* const A = filter;
        int lda = m;
        const float* const B = columns;
        int ldb = k;
        float beta = 1;
        float* C = out_features + (run_i * num_cols_per_run * out_channels);
        int ldc = m;

        int result =
                params.initialize(m,      // GEMM M dimension
                                  n,      // GEMM N dimension
                                  k,      // GEMM K dimension
                                  alpha,  // scalar alpha
                                  A,      // matrix A operand
                                  lda,
                                  B,      // matrix B operand
                                  ldb,
                                  beta,   // scalar beta
                                  C,      // source matrix C
                                  ldc,
                                  C,      // destination matrix C (may be different
                                          // memory than the source C matrix)
                                  ldc);

        if (result) {
            throw std::runtime_error(
                    "Failed to initialize CUTLASS Gemm::Params object.");
        }

        Gemm::launch(params, stream);
    }

    // Scale each output point's feature vector by its importance value.
    if (out_importance) {
        MultiplyColumns(stream, out_channels, num_out, out_features,
                        out_importance);
    }
}

}  // namespace impl
}  // namespace ml
}  // namespace cloudViewer