1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
12 #include <cutlass/gemm/gemm.h>
13 #include <cutlass/gemm/sgemm_traits.h>
15 #include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
16 #include "ml/impl/misc/MemoryAllocation.h"
17 #include "ml/impl/sparse_conv/SparseConvCUDAKernels.h"
19 using cloudViewer::utility::DivUp;
21 namespace cloudViewer {
25 /// Computes the backprop for the filter of a transpose sparse convolution.
27 /// All pointer arguments point to device memory unless stated otherwise.
29 /// \param temp Pointer to temporary memory. If nullptr then the required
30 /// size of temporary memory will be written to \p temp_size and no
31 /// work is done. This function can make use of more memory and
32 /// returns the maximum size that can be used in max_temp_size.
34 /// \param temp_size The size of the temporary memory in bytes. This is
35 /// used as an output if temp is nullptr and returns the minimum temp
38 /// \param max_temp_size This is used as an output if temp is nullptr and
39 /// returns the maximum temp size that can be used.
41 /// \param texture_alignment The texture alignment in bytes. This is used
42 /// for allocating segments within the temporary memory.
44 /// \param filter_backprop Output array for the computed filter gradient
45 /// with shape [depth,height,width, inp channels, out channels]
47 /// \param filter_dims The sizes of the filter dimensions. The size of
48 /// filter_dims must be >=3. The order is
49 /// [num kernel elements, inp channels, out channels].
51 /// \param filter Pointer to the filter values.
53 /// \param num_out The number of output points.
55 /// \param out_importance Optional importance for each output point with
56 /// shape [num_out]. Set to null to disable.
58 /// \param num_inp The number of input points.
60 /// \param inp_features The input features with shape
61 /// [num_inp, in_channels].
63 /// \param inp_neighbors_importance_sum The sum of the neighbors_importance
64 /// values for each input with shape [num_inp].
66 /// \param inp_neighbors_row_splits The prefix sum which defines the start
67 /// and end of the sublists in \p inp_neighbors_index. The size of the
68 /// array is \p num_inp + 1.
70 /// \param neighbors_index_size The size of the neighbors_index array.
72 /// \param neighbors_index The array with lists of neighbors for each
73 /// output point. The start and end of each sublist is defined by
74 /// \p neighbors_row_splits.
76 /// \param neighbors_kernel_index Defines which kernel element to use for
77 /// each neighbor. This array has the same length as \p neighbors_index.
79 /// \param neighbors_importance Optional importance for each entry in
80 /// \p neighbors_index. Set to null to disable.
82 /// \param neighbors_row_splits The prefix sum which defines the start
83 /// and end of the sublists in \p neighbors_index. The size of the
84 /// array is \p num_out + 1.
86 /// \param out_features_gradient The gradient from the features with shape
87 /// [num_out, out_channels]
89 /// \param normalize If true then the result is normalized either by the
90 /// number of points (neighbors_importance is null) or by the sum of
91 /// the respective values in neighbors_importance.
/// \throws std::runtime_error if the usable temporary memory is smaller than
/// the minimum required size, or with the message about failing to
/// initialize the CUTLASS Gemm::Params object (see the throw sites below).
93 template <class TFeat, class TOut, class TIndex, class TKernelIndex>
94 void SparseConvTransposeBackpropFilterCUDA(
95 const cudaStream_t& stream,
98 size_t& max_temp_size,
99 int texture_alignment,
100 TOut* filter_backprop,
101 const std::vector<int>& filter_dims,
103 const TFeat* out_importance,
105 const TFeat* inp_features,
106 const TFeat* inp_neighbors_importance_sum,
107 const int64_t* inp_neighbors_row_splits,
108 size_t neighbors_index_size,
109 const TIndex* neighbors_index,
110 const TKernelIndex* neighbors_kernel_index,
111 const TFeat* neighbors_importance,
112 const int64_t* neighbors_row_splits,
113 const TFeat* out_features_gradient,
// A null temp pointer selects the size-query mode.
115 const bool get_temp_size = !temp;
// Dummy non-null base address and a practically unbounded size, so that the
// allocator constructed below can be used to measure the required sizes.
// NOTE(review): presumably only executed when get_temp_size is set - confirm.
118 temp = (char*)1; // worst case alignment
119 temp_size = std::numeric_limits<int64_t>::max();
// Allocator that hands out segments of the temporary buffer, constructed
// with texture_alignment as its alignment argument.
122 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
// The last two entries of filter_dims are the input and output channel counts.
124 const int in_channels = filter_dims[filter_dims.size() - 2];
125 const int out_channels = filter_dims[filter_dims.size() - 1];
// All leading entries of filter_dims are multiplied to get the number of
// kernel elements.
// NOTE(review): filter_dims.size() - 2 is unsigned arithmetic and would wrap
// around for fewer than 2 entries; the documented precondition is size >= 3.
127 int num_kernel_elements = 1;
128 for (std::size_t i = 0; i < filter_dims.size() - 2; ++i)
129 num_kernel_elements *= filter_dims[i];
131 // This defines the minimum amount of temporary storage: we want room for
132 // at least 32 output points (or for all of them if there are fewer).
133 const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
// The upper bound is one column per output point.
134 const size_t max_num_cols_per_run = num_out;
// Each column holds num_kernel_elements * in_channels values of type TFeat.
135 size_t bytes_per_column =
136 sizeof(TFeat) * (num_kernel_elements * in_channels);
// With out_importance an extra out_channels values per column are reserved;
// this is the gradient buffer written by MultiplyAndCopyColumns below.
137 if (out_importance) bytes_per_column += sizeof(TFeat) * out_channels;
138 const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
139 const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
// Size query: the peak usage after the minimum request is reported in
// temp_size, the peak usage after the maximum request in max_temp_size.
142 std::pair<char*, size_t> tmp =
143 mem_temp.Alloc<char>(min_temp_size_bytes);
144 temp_size = mem_temp.MaxUsed();
146 mem_temp.Alloc<char>(max_temp_size_bytes);
147 max_temp_size = mem_temp.MaxUsed();
// Obtain the segment used for the column (and gradient) buffers; going by the
// method name this is the largest segment the allocator can provide.
151 std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
// Number of output points that fit into the segment at once, capped at num_out.
153 const size_t num_cols_per_run =
154 std::min(mem_columns.second / bytes_per_column, size_t(num_out));
// Reject buffers that cannot hold the minimum number of columns.
156 if (mem_columns.second < min_temp_size_bytes) {
157 std::stringstream ss;
158 ss << "temp is too small " << mem_columns.second
159 << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
160 throw std::runtime_error(ss.str());
// Byte count of num_kernel_elements * in_channels * out_channels values of
// type TOut, which matches the documented shape of filter_backprop.
165 sizeof(TOut) * num_kernel_elements * in_channels * out_channels,
// CUTLASS SGEMM configuration.
// NOTE(review): SgemmTraits and the float pointers A, B and C further below
// fix the element type to float although TFeat and TOut are template
// parameters - presumably only float instantiations are intended; confirm.
168 typedef cutlass::gemm::SgemmTraits<
169 cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
170 cutlass::MatrixLayout::kRowMajor, // layout of B matrix
171 cutlass::Shape<8, 64, 64> // threadblock tile size
175 typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
// The column buffer starts at the beginning of the segment; the gradient
// buffer follows after num_cols_per_run columns.
177 TFeat* columns = (TFeat*)mem_columns.first;
178 TFeat* gradient = ((TFeat*)mem_columns.first) +
179 num_cols_per_run * num_kernel_elements * in_channels;
181 // If we cannot process all data at once we need multiple runs.
182 size_t num_runs = DivUp(num_out, num_cols_per_run);
183 for (size_t run_i = 0; run_i < num_runs; ++run_i) {
// [begin_idx, end_idx) is the range of output points handled in this run.
// NOTE(review): the size_t products are narrowed to TIndex here.
184 const TIndex begin_idx = run_i * num_cols_per_run;
185 const TIndex end_idx =
186 std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
187 const size_t num_cols_this_run = end_idx - begin_idx;
189 if (out_importance) {
// Write the slice of out_features_gradient for this run into the gradient
// buffer, combined with the matching out_importance values by
// MultiplyAndCopyColumns.
190 MultiplyAndCopyColumns(
191 stream, out_channels, num_cols_this_run, gradient,
192 out_features_gradient +
193 (run_i * num_cols_per_run * out_channels),
194 out_importance + (run_i * num_cols_per_run));
// Point gradient directly at the slice of out_features_gradient for this
// run; the const_cast is needed because gradient is a non-const pointer.
// NOTE(review): presumably the branch without out_importance - confirm.
196 gradient = const_cast<TFeat*>(
197 out_features_gradient +
198 (run_i * num_cols_per_run * out_channels));
// Fill the column buffer for the output points of this run.
201 FillColumnTranspose<TFeat, TIndex>(
202 stream, columns, in_channels, begin_idx, end_idx, num_out,
203 num_inp, inp_features, inp_neighbors_importance_sum,
204 inp_neighbors_row_splits, neighbors_index_size, neighbors_index,
205 neighbors_kernel_index, neighbors_importance,
206 neighbors_row_splits, num_kernel_elements, normalize);
208 typename Gemm::Params params;
// GEMM dimensions: M is the number of output channels, K the number of
// columns in this run and N the number of kernel elements times in_channels.
212 int m = out_channels;
213 int k = num_cols_this_run;
214 int n = num_kernel_elements * in_channels;
// Operands: A is the gradient, B the column buffer and C the filter gradient.
216 const float* const A = gradient;
218 const float* const B = columns;
221 float* C = filter_backprop;
225 params.initialize(m, // GEMM M dimension
226 n, // GEMM N dimension
227 k, // GEMM K dimension
228 alpha, // scalar alpha
229 A, // matrix A operand
231 B, // matrix B operand
234 C, // source matrix C
236 C, // destination matrix C (here the same buffer as the source)
240 throw std::runtime_error(
241 "Failed to initialize CUTLASS Gemm::Params object.");
// Launch the GEMM on the given stream.
244 Gemm::launch(params, stream);
250 } // namespace cloudViewer