1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
12 #include <cutlass/gemm/gemm.h>
13 #include <cutlass/gemm/sgemm_traits.h>
15 #include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
16 #include "ml/impl/misc/MemoryAllocation.h"
18 using cloudViewer::utility::DivUp;
20 namespace cloudViewer {
24 /// Computes the backprop for the filter of a transpose continuous convolution.
26 /// All pointer arguments point to device memory unless stated otherwise.
28 /// \tparam TFeat Type for the features and weights
29 /// \tparam TOut Type for the output features
30 /// \tparam TReal Type for point positions and extents
31 /// \tparam TIndex Type for neighbor indexing
33 /// \param temp Pointer to temporary memory. If nullptr then the required
34 /// size of temporary memory will be written to \p temp_size and no
35 /// work is done. This function can make use of more memory and
36 /// returns the maximum size that can be used in max_temp_size.
38 /// \param temp_size The size of the temporary memory in bytes. This is
39 /// used as an output if temp is nullptr and returns the minimum temp
42 /// \param max_temp_size This is used as an output if temp is nullptr and
43 /// returns the maximum temp size that can be used.
45 /// \param texture_alignment The texture alignment in bytes. This is used
46 /// for allocating segments within the temporary memory.
48 /// \param filter_backprop Output array for the computed filter gradient
49 /// with shape [depth, height, width, inp channels, out channels]
51 /// \param filter_dims The sizes of the filter dimensions. The size of
52 /// filter_dims must be 5. The order is
53 /// [depth, height, width, inp channels, out channels].
55 /// \param filter Pointer to the filter values.
57 /// \param num_out The number of output points.
59 /// \param out_positions The positions of the output points. The shape is
62 /// \param out_importance Optional importance for each output point with
63 /// shape [num_out]. Set to null to disable.
65 /// \param num_inp The number of input points.
67 /// \param inp_positions The positions of the input points. The shape is
70 /// \param inp_features The input features with shape
71 /// [num_inp, in_channels].
73 /// \param inp_neighbors_importance_sum The sum of the neighbors_importance
74 /// values for each input with shape [num_inp].
76 /// \param inp_neighbors_row_splits The prefix sum which defines the start
77 /// and end of the sublists in \p inp_neighbors_index. The size of the
78 /// array is \p num_inp + 1.
80 /// \param neighbors_index_size The size of the neighbors_index array.
82 /// \param neighbors_index The array with lists of neighbors for each
83 /// output point. The start and end of each sublist is defined by
84 /// \p neighbors_row_splits.
86 /// \param neighbors_importance Optional importance for each entry in
87 /// \p neighbors_index. Set to null to disable.
89 /// \param neighbors_row_splits The prefix sum which defines the start
90 /// and end of the sublists in \p neighbors_index. The size of the
91 /// array is \p num_out + 1.
93 /// \param extents The spatial extents of the filter in coordinate units.
94 /// extents can be a scalar or a 1D array of shape [num_out] or a
95 /// 2D array of shape [num_out,3]. The shape depends on
96 /// \p individual_extent and \p isotropic_extent.
98 /// \param offsets A single 3D vector used in the filter coordinate
99 /// computation. The shape is [3].
101 /// \param out_features_gradient The gradient from the features with shape
102 /// [num_out, out_channels]
104 /// \param interpolation The interpolation mode. Either LINEAR or
105 /// NEAREST_NEIGHBOR.
107 /// \param coordinate_mapping The coordinate mapping function. One of
108 /// IDENTITY, BALL_TO_CUBE_RADIAL, BALL_TO_CUBE_VOLUME_PRESERVING.
110 /// \param align_corners If true then the voxel centers of the outer voxels
111 /// of the filter array are mapped to the boundary of the filter shape.
112 /// If false then the boundary of the filter array is mapped to the
113 /// boundary of the filter shape.
115 /// \param individual_extent If true each output point has an individual
118 /// \param isotropic_extent If true then the extent is isotropic for
119 /// each output point.
121 /// \param normalize If true then the result is normalized either by the
122 /// number of points (neighbors_importance is null) or by the sum of
123 /// the respective values in neighbors_importance.
125 template <class TFeat, class TOut, class TReal, class TIndex>
126 void CConvTransposeBackpropFilterCUDA(const cudaStream_t& stream,
129 size_t& max_temp_size,
130 int texture_alignment,
131 TOut* filter_backprop,
132 const std::vector<int>& filter_dims,
134 const TReal* out_positions,
135 const TFeat* out_importance,
137 const TReal* inp_positions,
138 const TFeat* inp_features,
139 const TFeat* inp_neighbors_importance_sum,
140 const int64_t* inp_neighbors_row_splits,
141 size_t neighbors_index_size,
142 const TIndex* neighbors_index,
143 const TFeat* neighbors_importance,
144 const int64_t* neighbors_row_splits,
145 const TReal* extents,
146 const TReal* offsets,
147 const TFeat* out_features_gradient,
148 InterpolationMode interpolation,
149 CoordinateMapping coordinate_mapping,
151 bool individual_extent,
152 bool isotropic_extent,
// Overview: the output points are processed in batches ("runs"). For each
// run a column matrix is assembled with FillColumnTranspose and multiplied
// with the matching slice of the output gradient by a CUTLASS GEMM whose
// source and destination matrix is filter_backprop.
//
// NOTE(review): this excerpt does not show every line of the definition.
// `temp`, `temp_size`, `num_out`, `num_inp`, `align_corners`, `normalize`,
// `alpha` and `GemmTraits` are used below without a visible declaration, and
// the `if`/`else` headers and closing braces of several branches are not
// visible either. Comments about those parts are hedged accordingly.

// A null `temp` selects the size-query mode described in the \param temp
// documentation.
154 const bool get_temp_size = !temp;
// NOTE(review): presumably the next two assignments are guarded by
// `get_temp_size` (guard not visible here). They give the allocator a dummy,
// maximally misaligned base address and an effectively unlimited capacity so
// that the required sizes can be measured without real memory.
157 temp = (char*)1; // worst case alignment
158 temp_size = std::numeric_limits<int64_t>::max();
// Helper that hands out segments of the temporary buffer using
// texture_alignment.
161 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
// The channel counts are the last two entries of filter_dims.
163 const int in_channels = filter_dims[filter_dims.size() - 2];
164 const int out_channels = filter_dims[filter_dims.size() - 1];
// Number of spatial filter cells: product of the first three filter_dims.
166 int spatial_filter_size = 1;
167 for (int i = 0; i < 3; ++i) spatial_filter_size *= filter_dims[i];
169 // this defines how much temporary storage we need at least
170 // we want to allocate memory for at least 32 output points.
171 const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
172 const size_t max_num_cols_per_run = num_out;
// One column holds spatial_filter_size * in_channels values of type TFeat.
173 size_t bytes_per_column =
174 sizeof(TFeat) * (spatial_filter_size * in_channels);
// With out_importance set, each column also needs room for out_channels
// values: the importance-scaled gradient copy written by
// MultiplyAndCopyColumns further below.
175 if (out_importance) bytes_per_column += sizeof(TFeat) * out_channels;
176 const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
177 const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
// NOTE(review): presumably the following trial allocations belong to the
// size-query branch (the guard, any intermediate free/return statements and
// the closing brace are not visible). They record the allocator's
// high-water mark for the minimum and the maximum useful buffer size.
180 std::pair<char*, size_t> tmp =
181 mem_temp.Alloc<char>(min_temp_size_bytes);
182 temp_size = mem_temp.MaxUsed();
184 mem_temp.Alloc<char>(max_temp_size_bytes);
185 max_temp_size = mem_temp.MaxUsed();
// Take the largest available segment of the temporary buffer for the
// per-run data.
189 std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
// Number of output points that fit into the segment, capped at num_out.
191 const size_t num_cols_per_run =
192 std::min(mem_columns.second / bytes_per_column, size_t(num_out));
// Reject buffers that cannot hold the minimum number of columns.
194 if (mem_columns.second < min_temp_size_bytes) {
195 std::stringstream ss;
196 ss << "temp is too small " << mem_columns.second
197 << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
198 throw std::runtime_error(ss.str());
// NOTE(review): the next line is an argument of a call whose beginning and
// end are not visible. The byte count matches an array of
// spatial_filter_size * in_channels * out_channels TOut values; presumably
// this clears filter_backprop before accumulation. Confirm against the full
// file.
203 sizeof(TOut) * spatial_filter_size * in_channels * out_channels,
// CUTLASS single-precision GEMM configuration (remaining template arguments
// and the name of this typedef are not visible in this excerpt).
206 typedef cutlass::gemm::SgemmTraits<
207 cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
208 cutlass::MatrixLayout::kRowMajor, // layout of B matrix
209 cutlass::Shape<8, 64, 64> // threadblock tile size
213 typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
// Layout of the segment: the column matrix comes first, followed by the
// scratch area for the importance-scaled gradient.
215 TFeat* columns = (TFeat*)mem_columns.first;
216 TFeat* gradient = ((TFeat*)mem_columns.first) +
217 num_cols_per_run * spatial_filter_size * in_channels;
219 // if we cannot process all data at once we need multiple runs
// NOTE(review): num_cols_per_run is 0 when num_out is 0; verify that DivUp
// (defined elsewhere) tolerates a zero divisor or that callers never pass
// num_out == 0.
220 size_t num_runs = DivUp(num_out, num_cols_per_run);
221 for (size_t run_i = 0; run_i < num_runs; ++run_i) {
// Half-open range [begin_idx, end_idx) of output points handled in this run.
// The size_t products are converted to TIndex here.
222 const TIndex begin_idx = run_i * num_cols_per_run;
223 const TIndex end_idx =
224 std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
225 const size_t num_cols_this_run = end_idx - begin_idx;
// With out_importance, the slice of out_features_gradient for this run is
// combined with the matching out_importance values and written to the
// scratch area `gradient`.
227 if (out_importance) {
228 MultiplyAndCopyColumns(
229 stream, out_channels, num_cols_this_run, gradient,
230 out_features_gradient +
231 (run_i * num_cols_per_run * out_channels),
232 out_importance + (run_i * num_cols_per_run));
// NOTE(review): presumably this is the `else` branch (header not visible):
// without importance the GEMM reads the gradient slice in place, so
// `gradient` is redirected to it.
234 gradient = const_cast<TFeat*>(
235 out_features_gradient +
236 (run_i * num_cols_per_run * out_channels));
// Assemble the column matrix for the output points of this run.
239 FillColumnTranspose<TFeat, TReal, TIndex>(
240 stream, columns, in_channels, begin_idx, end_idx, num_out,
241 out_positions, num_inp, inp_positions, inp_features,
242 inp_neighbors_importance_sum, inp_neighbors_row_splits,
243 neighbors_index_size, neighbors_index, neighbors_importance,
244 neighbors_row_splits, extents, offsets, filter_dims,
245 interpolation, coordinate_mapping, align_corners,
246 individual_extent, isotropic_extent, normalize);
248 typename Gemm::Params params;
// GEMM problem size: M = out_channels, N = spatial_filter_size * in_channels,
// K = number of output points in this run.
252 int m = out_channels;
253 int k = num_cols_this_run;
254 int n = spatial_filter_size * in_channels;
// The operands are typed as float pointers, so these initializations only
// compile when TFeat and TOut are float.
256 const float* const A = gradient;
258 const float* const B = columns;
261 float* C = filter_backprop;
// filter_backprop is passed as both source and destination matrix.
// NOTE(review): whether the runs accumulate depends on the scalar factors
// and leading dimensions passed on lines that are not visible here.
265 params.initialize(m, // GEMM M dimension
266 n, // GEMM N dimension
267 k, // GEMM K dimension
268 alpha, // scalar alpha
269 A, // matrix A operand
271 B, // matrix B operand
274 C, // source matrix C
276 C, // destination matrix C (may be different
// Raised when parameter initialization fails (the condition is not visible
// in this excerpt).
280 throw std::runtime_error(
281 "Failed to initialize CUTLASS Gemm::Params object.");
// Enqueue the GEMM on the caller's stream.
284 Gemm::launch(params, stream);
290 } // namespace cloudViewer