ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
ContinuousConvTransposeBackpropFilter.cuh
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 #define EIGEN_USE_GPU
10 
#include <algorithm>
#include <cstdint>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include <Helper.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>

#include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
#include "ml/impl/misc/MemoryAllocation.h"
17 
18 using cloudViewer::utility::DivUp;
19 
20 namespace cloudViewer {
21 namespace ml {
22 namespace impl {
23 
24 /// Computes the backprop for the filter of a transpose continuous convolution.
25 ///
26 /// All pointer arguments point to device memory unless stated otherwise.
27 ///
28 /// \tparam TFeat Type for the features and weights
29 /// \tparam TOut Type for the output features
30 /// \tparam TReal Type for point positions and extents
31 /// \tparam TIndex Type for neighbor indexing
32 ///
33 /// \param temp Pointer to temporary memory. If nullptr then the required
34 /// size of temporary memory will be written to \p temp_size and no
35 /// work is done. This function can make use of more memory and
36 /// returns the maximum size that can be used in max_temp_size.
37 ///
38 /// \param temp_size The size of the temporary memory in bytes. This is
39 /// used as an output if temp is nullptr and returns the minimum temp
40 /// size required.
41 ///
42 /// \param max_temp_size This is used as an output if temp is nullptr and
43 /// returns the maximum temp size that can be used.
44 ///
45 /// \param texture_alignment The texture alignment in bytes. This is used
46 /// for allocating segments within the temporary memory.
47 ///
48 /// \param filter_backrop Output array for the computed filter gradient
49 /// with shape [depth,height,witdth, inp channels, out channels]
50 ///
51 /// \param filter_dims The sizes of the filter dimensions. The size of
52 /// filter_dims must be 5. The order is
53 /// [depth, height, width, inp channels, out channels].
54 ///
55 /// \param filter Pointer to the filter values.
56 ///
57 /// \param num_out The number of output points.
58 ///
59 /// \param out_positions The positions of the output points. The shape is
60 /// [num_out, 3].
61 ///
62 /// \param out_importance Optional importance for each output point with
63 /// shape [num_out]. Set to null to disable.
64 ///
65 /// \param num_inp The number of input points.
66 ///
67 /// \param inp_positions The positions of the input points. The shape is
68 /// [num_inp, 3].
69 ///
70 /// \param inp_features The input features with shape
71 /// [num_inp, in_channels].
72 ///
73 /// \param inp_neighbors_importance_sum The sum of the neighbors_importance
74 /// values for each input with shape [num_inp].
75 ///
76 /// \param inp_neighbors_row_splits The prefix sum which defines the start
77 /// and end of the sublists in \p inp_neighbors_index. The size of the
78 /// array is \p num_inp + 1.
79 ///
80 /// \param neighbors_index_size The size of the neighbors_index array.
81 ///
82 /// \param neighbors_index The array with lists of neighbors for each
83 /// output point. The start and end of each sublist is defined by
84 /// \p neighbors_row_splits.
85 ///
86 /// \param neighbors_importance Optional importance for each entry in
87 /// \p neighbors_index. Set to null to disable.
88 ///
89 /// \param neighbors_row_splits The prefix sum which defines the start
90 /// and end of the sublists in \p neighbors_index. The size of the
91 /// array is \p num_out + 1.
92 ///
93 /// \param extents The spatial extents of the filter in coordinate units.
94 /// extents can be a scalar or a 1D array of shape [num_out] or a
95 /// 2D array of shape [num_out,3]. The shape depends on
96 /// \p individual_extent and \p isotropic_extent.
97 ///
98 /// \param offsets A single 3D vector used in the filter coordinate
99 /// computation. The shape is [3].
100 ///
101 /// \param out_features_gradient The gradient from the features with shape
102 /// [num_out, out_channels]
103 ///
104 /// \param interpolation The interpolation mode. Either LINEAR or
105 /// NEAREST_NEIGHBOR.
106 ///
107 /// \param coordinate_mapping The coordinate mapping function. One of
108 /// IDENTITY, BALL_TO_CUBE_RADIAL, BALL_TO_CUBE_VOLUME_PRESERVING.
109 ///
110 /// \param align_corners If true then the voxel centers of the outer voxels
111 /// of the filter array are mapped to the boundary of the filter shape.
112 /// If false then the boundary of the filter array is mapped to the
113 /// boundary of the filter shape.
114 ///
115 /// \param individual_extent If true each output point has an individual
116 /// extent.
117 ///
118 /// \param isotropic_extent If true each then the extent is isotropic for
119 /// each output point.
120 ///
121 /// \param normalize If true then the result is normalized either by the
122 /// number of points (neighbors_importance is null) or by the sum of
123 /// the respective values in neighbors_importance.
124 ///
125 template <class TFeat, class TOut, class TReal, class TIndex>
126 void CConvTransposeBackpropFilterCUDA(const cudaStream_t& stream,
127  void* temp,
128  size_t& temp_size,
129  size_t& max_temp_size,
130  int texture_alignment,
131  TOut* filter_backprop,
132  const std::vector<int>& filter_dims,
133  TIndex num_out,
134  const TReal* out_positions,
135  const TFeat* out_importance,
136  TIndex num_inp,
137  const TReal* inp_positions,
138  const TFeat* inp_features,
139  const TFeat* inp_neighbors_importance_sum,
140  const int64_t* inp_neighbors_row_splits,
141  size_t neighbors_index_size,
142  const TIndex* neighbors_index,
143  const TFeat* neighbors_importance,
144  const int64_t* neighbors_row_splits,
145  const TReal* extents,
146  const TReal* offsets,
147  const TFeat* out_features_gradient,
148  InterpolationMode interpolation,
149  CoordinateMapping coordinate_mapping,
150  bool align_corners,
151  bool individual_extent,
152  bool isotropic_extent,
153  bool normalize) {
154  const bool get_temp_size = !temp;
155 
156  if (get_temp_size) {
157  temp = (char*)1; // worst case alignment
158  temp_size = std::numeric_limits<int64_t>::max();
159  }
160 
161  MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
162 
163  const int in_channels = filter_dims[filter_dims.size() - 2];
164  const int out_channels = filter_dims[filter_dims.size() - 1];
165 
166  int spatial_filter_size = 1;
167  for (int i = 0; i < 3; ++i) spatial_filter_size *= filter_dims[i];
168 
169  // this defines how much temporary storage we need at least
170  // we want to allocate memory for at least 32 output points.
171  const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
172  const size_t max_num_cols_per_run = num_out;
173  size_t bytes_per_column =
174  sizeof(TFeat) * (spatial_filter_size * in_channels);
175  if (out_importance) bytes_per_column += sizeof(TFeat) * out_channels;
176  const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
177  const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
178 
179  if (get_temp_size) {
180  std::pair<char*, size_t> tmp =
181  mem_temp.Alloc<char>(min_temp_size_bytes);
182  temp_size = mem_temp.MaxUsed();
183  mem_temp.Free(tmp);
184  mem_temp.Alloc<char>(max_temp_size_bytes);
185  max_temp_size = mem_temp.MaxUsed();
186  return;
187  }
188 
189  std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
190 
191  const size_t num_cols_per_run =
192  std::min(mem_columns.second / bytes_per_column, size_t(num_out));
193 
194  if (mem_columns.second < min_temp_size_bytes) {
195  std::stringstream ss;
196  ss << "temp is too small " << mem_columns.second
197  << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
198  throw std::runtime_error(ss.str());
199  }
200 
201  cudaMemsetAsync(
202  filter_backprop, 0,
203  sizeof(TOut) * spatial_filter_size * in_channels * out_channels,
204  stream);
205 
206  typedef cutlass::gemm::SgemmTraits<
207  cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
208  cutlass::MatrixLayout::kRowMajor, // layout of B matrix
209  cutlass::Shape<8, 64, 64> // threadblock tile size
210  >
211  GemmTraits;
212 
213  typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
214 
215  TFeat* columns = (TFeat*)mem_columns.first;
216  TFeat* gradient = ((TFeat*)mem_columns.first) +
217  num_cols_per_run * spatial_filter_size * in_channels;
218 
219  // if we cannot process all data at once we need multiple runs
220  size_t num_runs = DivUp(num_out, num_cols_per_run);
221  for (size_t run_i = 0; run_i < num_runs; ++run_i) {
222  const TIndex begin_idx = run_i * num_cols_per_run;
223  const TIndex end_idx =
224  std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
225  const size_t num_cols_this_run = end_idx - begin_idx;
226 
227  if (out_importance) {
228  MultiplyAndCopyColumns(
229  stream, out_channels, num_cols_this_run, gradient,
230  out_features_gradient +
231  (run_i * num_cols_per_run * out_channels),
232  out_importance + (run_i * num_cols_per_run));
233  } else {
234  gradient = const_cast<TFeat*>(
235  out_features_gradient +
236  (run_i * num_cols_per_run * out_channels));
237  }
238 
239  FillColumnTranspose<TFeat, TReal, TIndex>(
240  stream, columns, in_channels, begin_idx, end_idx, num_out,
241  out_positions, num_inp, inp_positions, inp_features,
242  inp_neighbors_importance_sum, inp_neighbors_row_splits,
243  neighbors_index_size, neighbors_index, neighbors_importance,
244  neighbors_row_splits, extents, offsets, filter_dims,
245  interpolation, coordinate_mapping, align_corners,
246  individual_extent, isotropic_extent, normalize);
247 
248  typename Gemm::Params params;
249  // C is MxN
250  // B is KxN
251  // A is MxK
252  int m = out_channels;
253  int k = num_cols_this_run;
254  int n = spatial_filter_size * in_channels;
255  float alpha = 1;
256  const float* const A = gradient;
257  int lda = m;
258  const float* const B = columns;
259  int ldb = n;
260  float beta = 1;
261  float* C = filter_backprop;
262  int ldc = m;
263 
264  int result =
265  params.initialize(m, // GEMM M dimension
266  n, // GEMM N dimension
267  k, // GEMM K dimension
268  alpha, // scalar alpha
269  A, // matrix A operand
270  lda,
271  B, // matrix B operand
272  ldb,
273  beta, // scalar beta
274  C, // source matrix C
275  ldc,
276  C, // destination matrix C (may be different
277  ldc);
278 
279  if (result) {
280  throw std::runtime_error(
281  "Failed to initialize CUTLASS Gemm::Params object.");
282  }
283 
284  Gemm::launch(params, stream);
285  }
286 }
287 
288 } // namespace impl
289 } // namespace ml
290 } // namespace cloudViewer