1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
12 #include <cutlass/gemm/gemm.h>
13 #include <cutlass/gemm/sgemm_traits.h>
15 #include "ml/impl/misc/MemoryAllocation.h"
16 #include "ml/impl/sparse_conv/SparseConvCUDAKernels.h"
18 using cloudViewer::utility::DivUp;
20 namespace cloudViewer {
24 /// Computes the backprop for the filter of a sparse convolution.
26 /// All pointer arguments point to device memory unless stated otherwise.
28 /// \param temp Pointer to temporary memory. If nullptr then the required
29 /// size of temporary memory will be written to \p temp_size and no
30 /// work is done. This function can make use of more memory and
31 /// returns the maximum size that can be used in \p max_temp_size.
33 /// \param temp_size The size of the temporary memory in bytes. This is
34 /// used as an output if temp is nullptr and returns the minimum temp
37 /// \param max_temp_size This is used as an output if temp is nullptr and
38 /// returns the maximum temp size that can be used.
40 /// \param texture_alignment The texture alignment in bytes. This is used
41 /// for allocating segments within the temporary memory.
43 /// \param filter_backprop Output array for the computed filter gradient
44 /// with shape [depth, height, width, inp channels, out channels].
46 /// \param filter_dims The sizes of the filter dimensions. Must have at
47 /// least 3 entries, ordered [kernel dims..., inp channels, out channels];
48 /// the product of the kernel dims is the number of kernel elements.
50 /// \param num_out The number of output points.
52 /// \param num_inp The number of input points.
54 /// \param inp_features The input features with shape
55 /// [num_inp, in_channels].
57 /// \param inp_importance Optional importance for each input point with
58 /// shape [num_inp]. Set to null to disable.
60 /// \param neighbors_index_size The size of the neighbors_index array.
62 /// \param neighbors_index The array with lists of neighbors for each
63 /// output point. The start and end of each sublist is defined by
64 /// \p neighbors_row_splits.
66 /// \param neighbors_kernel_index Defines which kernel element to use for
67 /// each neighbor. This array has the same length as \p neighbors_index.
69 /// \param neighbors_importance Optional importance for each entry in
70 /// \p neighbors_index. Set to null to disable.
72 /// \param neighbors_row_splits The prefix sum which defines the start
73 /// and end of the sublists in \p neighbors_index. The size of the
74 /// array is \p num_out + 1.
76 /// \param normalize If true then the output features are normalized either
77 /// by the number of points (neighbors_importance is null) or by the sum
78 /// of the respective values in neighbors_importance.
80 template <class TFeat, class TOut, class TIndex, class TKernelIndex>
81 void SparseConvBackpropFilterCUDA(const cudaStream_t& stream,
84 size_t& max_temp_size,
85 int texture_alignment,
86 TOut* filter_backprop,
87 const std::vector<int>& filter_dims,
90 const TFeat* inp_features,
91 const TFeat* inp_importance,
92 size_t neighbors_index_size,
93 const TIndex* neighbors_index,
94 const TKernelIndex* neighbors_kernel_index,
95 const TFeat* neighbors_importance,
96 const int64_t* neighbors_row_splits,
97 const TFeat* out_features_gradient,
99 const bool get_temp_size = !temp;
102 temp = (char*)1; // worst case alignment
103 temp_size = std::numeric_limits<int64_t>::max();
106 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
108 const int in_channels = filter_dims[filter_dims.size() - 2];
109 const int out_channels = filter_dims[filter_dims.size() - 1];
111 int num_kernel_elements = 1;
112 for (std::size_t i = 0; i < filter_dims.size() - 2; ++i)
113 num_kernel_elements *= filter_dims[i];
115 // this defines how much temporary storage we need at least
116 // we want to allocate memory for at least 32 output points.
117 const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
118 const size_t max_num_cols_per_run = num_out;
119 const size_t bytes_per_column =
120 sizeof(TFeat) * (num_kernel_elements * in_channels);
121 const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
122 const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
125 std::pair<char*, size_t> tmp =
126 mem_temp.Alloc<char>(min_temp_size_bytes);
127 temp_size = mem_temp.MaxUsed();
129 mem_temp.Alloc<char>(max_temp_size_bytes);
130 max_temp_size = mem_temp.MaxUsed();
134 // Request segment using all of the temporary memory
135 std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
137 if (mem_columns.second < min_temp_size_bytes) {
138 std::stringstream ss;
139 ss << "temp is too small " << mem_columns.second
140 << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
141 throw std::runtime_error(ss.str());
147 sizeof(TOut) * num_kernel_elements * in_channels * out_channels,
150 size_t num_cols_per_run =
151 std::min(mem_columns.second / bytes_per_column, size_t(num_out));
153 typedef cutlass::gemm::SgemmTraits<
154 cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
155 cutlass::MatrixLayout::kRowMajor, // layout of B matrix
156 cutlass::Shape<8, 64, 64> // threadblock tile size
160 typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
162 TFeat* columns = (TFeat*)mem_columns.first;
164 // if we cannot process all data at once we need multiple runs
165 size_t num_runs = DivUp(num_out, num_cols_per_run);
166 for (size_t run_i = 0; run_i < num_runs; ++run_i) {
167 const TIndex begin_idx = run_i * num_cols_per_run;
168 const TIndex end_idx =
169 std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
170 const size_t num_cols_this_run = end_idx - begin_idx;
172 FillColumn<TFeat, TIndex, TKernelIndex>(
173 stream, columns, in_channels, begin_idx, end_idx, num_out,
174 num_inp, inp_features, inp_importance, neighbors_index_size,
175 neighbors_index, neighbors_kernel_index, neighbors_importance,
176 neighbors_row_splits, num_kernel_elements, normalize);
178 typename Gemm::Params params;
182 int m = out_channels;
183 int k = num_cols_this_run;
184 int n = num_kernel_elements * in_channels;
186 const float* const A = out_features_gradient +
187 (run_i * num_cols_per_run * out_channels);
189 const float* const B = columns;
192 float* C = filter_backprop;
195 int result = params.initialize(m, // GEMM M dimension
196 n, // GEMM N dimension
197 k, // GEMM K dimension
198 alpha, // scalar alpha
199 A, // matrix A operand
201 B, // matrix B operand
204 C, // source matrix C
206 C, // destination matrix C
210 throw std::runtime_error(
211 "Failed to initialize CUTLASS Gemm::Params object.");
214 Gemm::launch(params, stream);
220 } // namespace cloudViewer