1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
13 #include <cutlass/gemm/gemm.h>
14 #include <cutlass/gemm/sgemm_traits.h>
16 #include "ml/impl/misc/MemoryAllocation.h"
17 #include "ml/impl/sparse_conv/SparseConvCUDAKernels.h"
19 using cloudViewer::utility::DivUp;
21 namespace cloudViewer {
25 /// Computes the output features of a sparse convolution.
27 /// All pointer arguments point to device memory unless stated otherwise.
29 /// \param temp Pointer to temporary memory. If nullptr, then the required
30 /// size of temporary memory is written to \p temp_size and no
31 /// work is done. This function can make use of more memory and
32 /// returns the maximum size that can be used in \p max_temp_size.
34 /// \param temp_size The size of the temporary memory in bytes. This is
35 /// used as an output if temp is nullptr and returns the minimum temp
38 /// \param max_temp_size Output if temp is nullptr: the maximum temp size
39 /// that can be used, i.e. enough to process all output points in one run.
41 /// \param texture_alignment The texture alignment in bytes. This is used
42 /// for allocating segments within the temporary memory.
44 /// \param out_features Output array for the computed features with shape
45 /// [num_out, out_channels].
47 /// \param filter_dims The sizes of the filter dimensions (size must be >= 3).
48 /// The last two entries are [in_channels, out_channels]; the product of
49 /// all preceding entries is the number of kernel elements.
51 /// \param filter Pointer to the filter values.
53 /// \param num_out The number of output points.
55 /// \param num_inp The number of input points.
57 /// \param inp_features The input features with shape
58 /// [num_inp, in_channels].
60 /// \param inp_importance Optional importance for each input point with
61 /// shape [num_inp]. Set to null to disable.
63 /// \param neighbors_index_size The size of the neighbors_index array.
65 /// \param neighbors_index The array with lists of neighbors for each
66 /// output point. The start and end of each sublist is defined by
67 /// \p neighbors_row_splits.
69 /// \param neighbors_kernel_index Defines which kernel element to use for
70 /// each neighbor. This array has the same length as \p neighbors_index.
72 /// \param neighbors_importance Optional importance for each entry in
73 /// \p neighbors_index. Set to null to disable.
75 /// \param neighbors_row_splits The prefix sum which defines the start
76 /// and end of the sublists in \p neighbors_index. The size of the
77 /// array is \p num_out + 1.
79 /// \param normalize If true then the result is normalized either by the
80 /// number of points (neighbors_importance is null) or by the sum of
81 /// the respective values in neighbors_importance.
// Host-side driver for the sparse convolution forward pass.
// As visible here, it either measures the temporary memory it needs (query
// mode, temp == nullptr) or does the following:
//   1. zeroes out_features on `stream`;
//   2. splits the num_out output points into runs whose patch-matrix
//      columns fit into the temp buffer;
//   3. for each run, builds the patch matrix with FillColumn and multiplies
//      the filter with it using a CUTLASS SGEMM launched on `stream`.
// All device work is enqueued on `stream`. No stream synchronization
// appears in the visible code.
83 template <class TFeat, class TOut, class TIndex, class TKernelIndex>
84 void SparseConvComputeFeaturesCUDA(const cudaStream_t& stream,
87 size_t& max_temp_size,
88 int texture_alignment,
90 const std::vector<int>& filter_dims,
94 const TFeat* inp_features,
95 const TFeat* inp_importance,
96 size_t neighbors_index_size,
97 const TIndex* neighbors_index,
98 const TKernelIndex* neighbors_kernel_index,
99 const TFeat* neighbors_importance,
100 const int64_t* neighbors_row_splits,
// A null temp pointer selects query mode: only the temp sizes are reported.
102 const bool get_temp_size = !temp;
// Query mode (presumably guarded by get_temp_size; the condition line is not
// visible here). Use a dummy base pointer with worst-case alignment and an
// effectively unlimited size, so MemoryAllocation only tracks how many bytes
// the requested segments would occupy.
105 temp = (char*)1; // worst case alignment
106 temp_size = std::numeric_limits<int64_t>::max();
// Allocates segments inside temp, respecting texture_alignment.
109 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
// The last two filter dims are [in_channels, out_channels]. Every leading
// dim is a factor of the number of kernel elements.
111 const int in_channels = filter_dims[filter_dims.size() - 2];
112 const int out_channels = filter_dims[filter_dims.size() - 1];
114 int num_kernel_elements = 1;
115 for (std::size_t i = 0; i < filter_dims.size() - 2; ++i)
116 num_kernel_elements *= filter_dims[i];
118 // this defines how much temporary storage we need at least.
119 // we want to allocate memory for at least 32 output points.
120 const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
121 const size_t max_num_cols_per_run = num_out;
// Each output point owns one patch-matrix column of
// num_kernel_elements * in_channels TFeat values (the GEMM K dimension below).
122 const size_t bytes_per_column =
123 sizeof(TFeat) * (num_kernel_elements * in_channels);
124 const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
125 const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
// Query mode (presumably): temp_size receives the peak usage after
// requesting room for the minimum run. max_temp_size receives the peak usage
// after requesting room for all num_out columns (a single run).
128 std::pair<char*, size_t> tmp =
129 mem_temp.Alloc<char>(min_temp_size_bytes);
130 temp_size = mem_temp.MaxUsed();
132 mem_temp.Alloc<char>(max_temp_size_bytes);
133 max_temp_size = mem_temp.MaxUsed();
137 // Request segment using all of the temporary memory
138 std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
// Fail loudly if the caller's buffer cannot hold even the minimum run.
140 if (mem_columns.second < min_temp_size_bytes) {
141 std::stringstream ss;
142 ss << "temp is too small " << mem_columns.second
143 << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
144 throw std::runtime_error(ss.str());
// out_features is also passed as the GEMM source matrix C below, so it is
// cleared first.
// NOTE(review): the cudaError_t returned by cudaMemsetAsync is not checked
// in the visible code.
148 cudaMemsetAsync(out_features, 0, sizeof(TOut) * num_out * out_channels,
// Use as many columns per run as fit into the segment, capped at num_out.
// NOTE(review): this divides by bytes_per_column, which is 0 when
// in_channels or any filter dim is 0.
// NOTE(review): if num_out == 0, num_cols_per_run is 0, so the DivUp call
// below gets a zero divisor. Consider an early return for these degenerate
// sizes. TODO: confirm how callers handle num_out == 0.
151 size_t num_cols_per_run =
152 std::min(mem_columns.second / bytes_per_column, size_t(num_out));
// Column-major single-precision GEMM.
// NOTE(review): only float pointers are passed to it below, so TFeat must be
// float (columns is a TFeat*) and out_features must be a float*.
154 typedef cutlass::gemm::SgemmTraits<
155 cutlass::MatrixLayout::kColumnMajor, // layout of A matrix (filter)
156 cutlass::MatrixLayout::kColumnMajor, // layout of B matrix
158 cutlass::Shape<8, 64, 64> // threadblock tile size
162 typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
164 // this is the pointer to the patch matrix
165 TFeat* columns = (TFeat*)mem_columns.first;
167 // if we cannot process all data at once we need multiple runs
168 const size_t num_runs = DivUp(num_out, num_cols_per_run);
// Each run handles output points [begin_idx, end_idx) and reuses the same
// columns buffer. FillColumn and the GEMM are both enqueued on `stream`,
// so the buffer reuse across runs is ordered by the stream.
169 for (size_t run_i = 0; run_i < num_runs; ++run_i) {
170 const TIndex begin_idx = run_i * num_cols_per_run;
171 const TIndex end_idx =
172 std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
173 const size_t num_cols_this_run = end_idx - begin_idx;
175 // compute the patch matrix
176 FillColumn<TFeat, TIndex>(
177 stream, columns, in_channels, begin_idx, end_idx, num_out,
178 num_inp, inp_features, inp_importance, neighbors_index_size,
179 neighbors_index, neighbors_kernel_index, neighbors_importance,
180 neighbors_row_splits, num_kernel_elements, normalize);
// GEMM shapes: A (filter) is m x k, B (columns) is k x n, and the
// column-major result is m x n with m == out_channels. Each result column is
// one row of out_features [num_out, out_channels] for this run's points.
185 int m = out_channels;
186 int k = num_kernel_elements * in_channels;
187 int n = num_cols_this_run;
189 const float* const A = filter;
191 const float* const B = columns;
// Output slice for this run: rows begin_idx.. of out_features.
194 float* C = out_features + (run_i * num_cols_per_run * out_channels);
// CUTLASS computes D = alpha * A * B + beta * C. Here C and D are the same
// slice of out_features.
197 typename Gemm::Params params;
198 int result = params.initialize(m, // GEMM M dimension
199 n, // GEMM N dimension
200 k, // GEMM K dimension
201 alpha, // scalar alpha
202 A, // matrix A operand
204 B, // matrix B operand
207 C, // source matrix C
209 C, // destination matrix C
213 throw std::runtime_error(
214 "Failed to initialize CUTLASS Gemm::Params object.");
// Asynchronous launch on `stream`.
// NOTE(review): no launch error check (e.g. cudaGetLastError) is visible
// after this call.
217 Gemm::launch(params, stream);
223 } // namespace cloudViewer