ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
SparseConvTransposeBackpropFilter.cuh
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 #define EIGEN_USE_GPU
10 
11 #include <Helper.h>
12 #include <cutlass/gemm/gemm.h>
13 #include <cutlass/gemm/sgemm_traits.h>
14 
15 #include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
16 #include "ml/impl/misc/MemoryAllocation.h"
17 #include "ml/impl/sparse_conv/SparseConvCUDAKernels.h"
18 
19 using cloudViewer::utility::DivUp;
20 
21 namespace cloudViewer {
22 namespace ml {
23 namespace impl {
24 
25 /// Computes the backprop for the filter of a transpose sparse convolution.
26 ///
27 /// All pointer arguments point to device memory unless stated otherwise.
28 ///
29 /// \param temp Pointer to temporary memory. If nullptr then the required
30 /// size of temporary memory will be written to \p temp_size and no
31 /// work is done. This function can make use of more memory and
32 /// returns the maximum size that can be used in max_temp_size.
33 ///
34 /// \param temp_size The size of the temporary memory in bytes. This is
35 /// used as an output if temp is nullptr and returns the minimum temp
36 /// size required.
37 ///
38 /// \param max_temp_size This is used as an output if temp is nullptr and
39 /// returns the maximum temp size that can be used.
40 ///
41 /// \param texture_alignment The texture alignment in bytes. This is used
42 /// for allocating segments within the temporary memory.
43 ///
44 /// \param filter_backprop Output array for the computed filter gradient
45 /// with shape [depth, height, width, inp channels, out channels]
46 ///
47 /// \param filter_dims The sizes of the filter dimensions. The size of
48 /// filter_dims must be >=3. The order is
49 /// [num kernel elements, inp channels, out channels].
50 ///
51 /// NOTE(review): a stale "\param filter" entry was removed here — the
52 /// function signature below takes no 'filter' argument; verify against the
53 /// forward-op documentation this block was likely copied from.
53 /// \param num_out The number of output points.
54 ///
55 /// \param out_importance Optional importance for each output point with
56 /// shape [num_out]. Set to null to disable.
57 ///
58 /// \param num_inp The number of input points.
59 ///
60 /// \param inp_features The input features with shape
61 /// [num_inp, in_channels].
62 ///
63 /// \param inp_neighbors_importance_sum The sum of the neighbors_importance
64 /// values for each input with shape [num_inp].
65 ///
66 /// \param inp_neighbors_row_splits The prefix sum which defines the start
67 /// and end of the sublists in \p inp_neighbors_index. The size of the
68 /// array is \p num_inp + 1.
69 ///
70 /// \param neighbors_index_size The size of the neighbors_index array.
71 ///
72 /// \param neighbors_index The array with lists of neighbors for each
73 /// output point. The start and end of each sublist is defined by
74 /// \p neighbors_row_splits.
75 ///
76 /// \param neighbors_kernel_index Defines which kernel element to use for
77 /// each neighbor. This array has the same length as \p neighbors_index.
78 ///
79 /// \param neighbors_importance Optional importance for each entry in
80 /// \p neighbors_index. Set to null to disable.
81 ///
82 /// \param neighbors_row_splits The prefix sum which defines the start
83 /// and end of the sublists in \p neighbors_index. The size of the
84 /// array is \p num_out + 1.
85 ///
86 /// \param out_features_gradient The gradient from the features with shape
87 /// [num_out, out_channels]
88 ///
89 /// \param normalize If true then the result is normalized either by the
90 /// number of points (neighbors_importance is null) or by the sum of
91 /// the respective values in neighbors_importance.
92 ///
// Computes the filter gradient of a transpose sparse convolution by gathering
// input features into a column matrix and accumulating per-run GEMM products
// into filter_backprop. See the doxygen block above for the full parameter
// description. Works in two phases: when temp==nullptr only the required /
// maximum temporary sizes are reported; otherwise the actual computation runs.
template <class TFeat, class TOut, class TIndex, class TKernelIndex>
void SparseConvTransposeBackpropFilterCUDA(
        const cudaStream_t& stream,
        void* temp,
        size_t& temp_size,
        size_t& max_temp_size,
        int texture_alignment,
        TOut* filter_backprop,
        const std::vector<int>& filter_dims,
        TIndex num_out,
        const TFeat* out_importance,
        TIndex num_inp,
        const TFeat* inp_features,
        const TFeat* inp_neighbors_importance_sum,
        const int64_t* inp_neighbors_row_splits,
        size_t neighbors_index_size,
        const TIndex* neighbors_index,
        const TKernelIndex* neighbors_kernel_index,
        const TFeat* neighbors_importance,
        const int64_t* neighbors_row_splits,
        const TFeat* out_features_gradient,
        bool normalize) {
    // temp==nullptr selects the "size query" mode (see doc comment above).
    const bool get_temp_size = !temp;

    if (get_temp_size) {
        // Use a dummy non-null pointer with worst-case alignment so that
        // MemoryAllocation performs the same bookkeeping it would with a
        // real buffer; no memory is touched in this mode.
        temp = (char*)1;  // worst case alignment
        temp_size = std::numeric_limits<int64_t>::max();
    }

    MemoryAllocation mem_temp(temp, temp_size, texture_alignment);

    // filter_dims is [num kernel elements..., in_channels, out_channels];
    // the last two entries are the channel counts, everything before them
    // multiplies into the kernel-element count.
    const int in_channels = filter_dims[filter_dims.size() - 2];
    const int out_channels = filter_dims[filter_dims.size() - 1];

    int num_kernel_elements = 1;
    for (std::size_t i = 0; i < filter_dims.size() - 2; ++i)
        num_kernel_elements *= filter_dims[i];

    // this defines how much temporary storage we need at least
    // we want to allocate memory for at least 32 output points.
    const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
    const size_t max_num_cols_per_run = num_out;
    // Each processed output point ("column") needs one gathered feature
    // column; with out_importance an extra scaled-gradient column is needed.
    size_t bytes_per_column =
            sizeof(TFeat) * (num_kernel_elements * in_channels);
    if (out_importance) bytes_per_column += sizeof(TFeat) * out_channels;
    const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
    const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;

    if (get_temp_size) {
        // Report the minimum workable size in temp_size and the size that
        // allows a single run over all outputs in max_temp_size, then return
        // without doing any work.
        std::pair<char*, size_t> tmp =
                mem_temp.Alloc<char>(min_temp_size_bytes);
        temp_size = mem_temp.MaxUsed();
        mem_temp.Free(tmp);
        mem_temp.Alloc<char>(max_temp_size_bytes);
        max_temp_size = mem_temp.MaxUsed();
        return;
    }

    // Grab everything that is left in the temp buffer; the number of columns
    // we can process per run is derived from what we actually got.
    std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();

    const size_t num_cols_per_run =
            std::min(mem_columns.second / bytes_per_column, size_t(num_out));

    if (mem_columns.second < min_temp_size_bytes) {
        std::stringstream ss;
        ss << "temp is too small " << mem_columns.second
           << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
        throw std::runtime_error(ss.str());
    }

    // Zero the output: the per-run GEMMs below accumulate (beta = 1).
    cudaMemsetAsync(
            filter_backprop, 0,
            sizeof(TOut) * num_kernel_elements * in_channels * out_channels,
            stream);

    // NOTE(review): this is the single-precision CUTLASS path (SgemmTraits,
    // float pointers below) — TFeat/TOut are presumably expected to be
    // float when this template is instantiated; confirm at the call sites.
    typedef cutlass::gemm::SgemmTraits<
            cutlass::MatrixLayout::kColumnMajor,  // layout of A matrix
            cutlass::MatrixLayout::kRowMajor,     // layout of B matrix
            cutlass::Shape<8, 64, 64>             // threadblock tile size
            >
            GemmTraits;

    typedef cutlass::gemm::Gemm<GemmTraits> Gemm;

    // Partition the temp segment: gathered feature columns first, then (only
    // used when out_importance != nullptr) the importance-scaled gradient.
    TFeat* columns = (TFeat*)mem_columns.first;
    TFeat* gradient = ((TFeat*)mem_columns.first) +
                      num_cols_per_run * num_kernel_elements * in_channels;

    // if we cannot process all data at once we need multiple runs
    size_t num_runs = DivUp(num_out, num_cols_per_run);
    for (size_t run_i = 0; run_i < num_runs; ++run_i) {
        const TIndex begin_idx = run_i * num_cols_per_run;
        const TIndex end_idx =
                std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
        const size_t num_cols_this_run = end_idx - begin_idx;

        if (out_importance) {
            // Scale this run's slice of the incoming gradient by the
            // per-output importance into the temp 'gradient' buffer.
            MultiplyAndCopyColumns(
                    stream, out_channels, num_cols_this_run, gradient,
                    out_features_gradient +
                            (run_i * num_cols_per_run * out_channels),
                    out_importance + (run_i * num_cols_per_run));
        } else {
            // No scaling needed: point 'gradient' directly at this run's
            // slice of out_features_gradient. The const_cast is safe because
            // the GEMM only reads through this pointer (operand A below).
            gradient = const_cast<TFeat*>(
                    out_features_gradient +
                    (run_i * num_cols_per_run * out_channels));
        }

        // Gather the neighbor features for outputs [begin_idx, end_idx) into
        // the column matrix, applying kernel-element indexing, neighbor
        // importance, and optional normalization.
        FillColumnTranspose<TFeat, TIndex>(
                stream, columns, in_channels, begin_idx, end_idx, num_out,
                num_inp, inp_features, inp_neighbors_importance_sum,
                inp_neighbors_row_splits, neighbors_index_size, neighbors_index,
                neighbors_kernel_index, neighbors_importance,
                neighbors_row_splits, num_kernel_elements, normalize);

        // Accumulate gradient^T(-ish) x columns into filter_backprop:
        //   C (m x n) += A (m x k) * B (k x n), beta = 1 so runs accumulate.
        typename Gemm::Params params;
        // C is MxN
        // B is KxN
        // A is MxK
        int m = out_channels;
        int k = num_cols_this_run;
        int n = num_kernel_elements * in_channels;
        float alpha = 1;
        const float* const A = gradient;
        int lda = m;
        const float* const B = columns;
        int ldb = n;
        float beta = 1;
        float* C = filter_backprop;
        int ldc = m;

        int result =
                params.initialize(m,      // GEMM M dimension
                                  n,      // GEMM N dimension
                                  k,      // GEMM K dimension
                                  alpha,  // scalar alpha
                                  A,      // matrix A operand
                                  lda,
                                  B,  // matrix B operand
                                  ldb,
                                  beta,  // scalar beta
                                  C,     // source matrix C
                                  ldc,
                                  C,  // destination matrix C (may be different
                                  ldc);

        if (result) {
            throw std::runtime_error(
                    "Failed to initialize CUTLASS Gemm::Params object.");
        }

        Gemm::launch(params, stream);
    }
}
247 
248 } // namespace impl
249 } // namespace ml
250 } // namespace cloudViewer