1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
13 #include <cutlass/gemm/gemm.h>
14 #include <cutlass/gemm/sgemm_traits.h>
16 #include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
17 #include "ml/impl/misc/MemoryAllocation.h"
19 using cloudViewer::utility::DivUp;
21 namespace cloudViewer {
25 /// Computes the output features of a continuous convolution.
///
/// The output points are processed in runs. For each run FillColumn writes
/// one patch-matrix column per output point into the temporary memory, and a
/// CUTLASS single-precision GEMM multiplies the filter with that patch matrix
/// to produce the corresponding rows of \p out_features.
27 /// All pointer arguments point to device memory unless stated otherwise.
29 /// \tparam TFeat Type for the features and weights
30 /// \tparam TOut Type for the output features
31 /// \tparam TReal Type for point positions and extents
32 /// \tparam TIndex Type for neighbor indexing
34 /// \param temp Pointer to temporary memory. If nullptr then the required
35 /// size of temporary memory will be written to \p temp_size and no
36 /// work is done. This function can make use of more memory; the
37 /// maximum size that can be used is written to \p max_temp_size.
39 /// \param temp_size The size of the temporary memory in bytes. This is
40 /// used as an output if temp is nullptr and returns the minimum temp
43 /// \param max_temp_size This is used as an output if temp is nullptr and
44 /// returns the maximum temp size that can be used.
46 /// \param texture_alignment The texture alignment in bytes. This is used
47 /// for allocating segments within the temporary memory.
49 /// \param out_features Output array for the computed features with shape
50 /// [num_out, out channels]
52 /// \param filter_dims The sizes of the filter dimensions. The size of
53 /// filter_dims must be 5. The order is
54 /// [depth, height, width, inp channels, out channels].
56 /// \param filter Pointer to the filter values.
58 /// \param num_out The number of output points.
60 /// \param out_positions The positions of the output points. The shape is
63 /// \param num_inp The number of input points.
65 /// \param inp_positions The positions of the input points. The shape is
68 /// \param inp_features The input features with shape
69 /// [num_inp, in_channels].
71 /// \param inp_importance Optional importance for each input point with
72 /// shape [num_inp]. Set to null to disable.
74 /// \param neighbors_index_size The size of the neighbors_index array.
76 /// \param neighbors_index The array with lists of neighbors for each
77 /// output point. The start and end of each sublist is defined by
78 /// \p neighbors_row_splits.
80 /// \param neighbors_importance Optional importance for each entry in
81 /// \p neighbors_index. Set to null to disable.
83 /// \param neighbors_row_splits The prefix sum which defines the start
84 /// and end of the sublists in \p neighbors_index. The size of the
85 /// array is \p num_out + 1.
87 /// \param extents The spatial extents of the filter in coordinate units.
88 /// extents can be a scalar or a 1D array of shape [num_out] or a
89 /// 2D array of shape [num_out,3]. The shape depends on
90 /// \p individual_extent and \p isotropic_extent.
92 /// \param offsets A single 3D vector used in the filter coordinate
93 /// computation. The shape is [3].
95 /// \param interpolation The interpolation mode. Either LINEAR or
98 /// \param coordinate_mapping The coordinate mapping function. One of
99 /// IDENTITY, BALL_TO_CUBE_RADIAL, BALL_TO_CUBE_VOLUME_PRESERVING.
101 /// \param align_corners If true then the voxel centers of the outer voxels
102 /// of the filter array are mapped to the boundary of the filter shape.
103 /// If false then the boundary of the filter array is mapped to the
104 /// boundary of the filter shape.
106 /// \param individual_extent If true each output point has an individual
109 /// \param isotropic_extent If true then the extent is isotropic for
110 /// each output point.
112 /// \param normalize If true then the result is normalized either by the
113 /// number of points (neighbors_importance is null) or by the sum of
114 /// the respective values in neighbors_importance.
///
/// \throws std::runtime_error if the temporary memory cannot hold the patch
/// columns for min(num_out, 32) output points, or if the CUTLASS GEMM
/// parameters cannot be initialized.
116 template <class TFeat, class TOut, class TReal, class TIndex>
117 void CConvComputeFeaturesCUDA(const cudaStream_t& stream,
120 size_t& max_temp_size,
121 int texture_alignment,
123 const std::vector<int>& filter_dims,
126 const TReal* out_positions,
128 const TReal* inp_positions,
129 const TFeat* inp_features,
130 const TFeat* inp_importance,
131 size_t neighbors_index_size,
132 const TIndex* neighbors_index,
133 const TFeat* neighbors_importance,
134 const int64_t* neighbors_row_splits,
135 const TReal* extents,
136 const TReal* offsets,
137 InterpolationMode interpolation,
138 CoordinateMapping coordinate_mapping,
140 bool individual_extent,
141 bool isotropic_extent,
143 const bool get_temp_size = !temp; // nullptr temp: only report temp_size / max_temp_size
146 temp = (char*)1; // worst case alignment: a base address of 1 presumably forces maximal padding
147 temp_size = std::numeric_limits<int64_t>::max(); // no size limit while measuring
150 MemoryAllocation mem_temp(temp, temp_size, texture_alignment); // sub-allocates segments of temp using texture_alignment
// filter_dims is [depth, height, width, inp channels, out channels].
152 const int in_channels = filter_dims[filter_dims.size() - 2];
153 const int out_channels = filter_dims[filter_dims.size() - 1];
155 int spatial_filter_size = 1;
156 for (int i = 0; i < 3; ++i) spatial_filter_size *= filter_dims[i]; // depth * height * width
158 // Temporary storage holds one patch-matrix column per output point.
159 // At least min(num_out, 32) columns must fit; more than num_out are never used.
160 const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
161 const size_t max_num_cols_per_run = num_out;
162 const size_t bytes_per_column =
163 sizeof(TFeat) * (spatial_filter_size * in_channels); // k = spatial_filter_size * in_channels values
164 const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
165 const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
// Size query: record the temp size needed for the minimum and for the
// maximum number of columns, as measured by the allocator.
168 std::pair<char*, size_t> tmp =
169 mem_temp.Alloc<char>(min_temp_size_bytes);
170 temp_size = mem_temp.MaxUsed();
172 mem_temp.Alloc<char>(max_temp_size_bytes);
173 max_temp_size = mem_temp.MaxUsed();
177 // Request segment using all of the temporary memory
178 std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
180 if (mem_columns.second < min_temp_size_bytes) { // not even min_num_cols_per_run columns fit
181 std::stringstream ss;
182 ss << "temp is too small " << mem_columns.second
183 << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
184 throw std::runtime_error(ss.str());
// Clear the output; the GEMM runs below each write a disjoint block of rows.
// NOTE(review): the cudaError_t returned by cudaMemsetAsync is not checked.
188 cudaMemsetAsync(out_features, 0, sizeof(TOut) * num_out * out_channels,
// Output points per run: as many patch columns as fit into the temp segment.
// NOTE(review): this divides by bytes_per_column (0 if a spatial filter dim
// or in_channels is 0), and when num_out == 0 num_cols_per_run becomes 0 and
// is passed as the second argument of DivUp below (presumably its divisor);
// confirm neither case can reach this point.
191 size_t num_cols_per_run =
192 std::min(mem_columns.second / bytes_per_column, size_t(num_out));
// CUTLASS single-precision GEMM configuration. A (filter) is m x k and
// B (patch columns) is k x n, both column-major.
194 typedef cutlass::gemm::SgemmTraits<
195 cutlass::MatrixLayout::kColumnMajor, // layout of A matrix (filter)
196 cutlass::MatrixLayout::kColumnMajor, // layout of B matrix (patch columns)
198 cutlass::Shape<8, 64, 64> // threadblock tile size
202 typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
204 // this is the pointer to the patch matrix (start of the temp segment)
205 TFeat* columns = (TFeat*)mem_columns.first;
207 // if we cannot process all data at once we need multiple runs
208 const size_t num_runs = DivUp(num_out, num_cols_per_run);
209 for (size_t run_i = 0; run_i < num_runs; ++run_i) {
// This run handles output points [begin_idx, end_idx).
210 const TIndex begin_idx = run_i * num_cols_per_run;
211 const TIndex end_idx =
212 std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
213 const size_t num_cols_this_run = end_idx - begin_idx;
215 // compute the patch matrix for output points [begin_idx, end_idx).
// columns is reused by every run; FillColumn and Gemm::launch both receive
// `stream`, so stream ordering presumably keeps the next fill from
// overwriting columns while the previous GEMM still reads it.
216 FillColumn<TFeat, TReal, TIndex>(
217 stream, columns, in_channels, begin_idx, end_idx, num_out,
218 out_positions, num_inp, inp_positions, inp_features,
219 inp_importance, neighbors_index_size, neighbors_index,
220 neighbors_importance, neighbors_row_splits, extents, offsets,
221 filter_dims, interpolation, coordinate_mapping, align_corners,
222 individual_extent, isotropic_extent, normalize);
// GEMM sizes: m = out_channels, k = patch length, n = points in this run.
227 int m = out_channels;
228 int k = spatial_filter_size * in_channels;
229 int n = num_cols_this_run;
// The operands are spelled as float: B is initialized from columns (a
// TFeat*), so this only compiles for TFeat = float, matching the
// single-precision SgemmTraits above.
231 const float* const A = filter;
233 const float* const B = columns;
// C points at row begin_idx of out_features ([num_out, out_channels]); read
// as an m x n column-major matrix, each column is one output point's
// out_channels values (assumes ldc == out_channels - TODO confirm).
236 float* C = out_features + (run_i * num_cols_per_run * out_channels);
239 typename Gemm::Params params;
240 int result = params.initialize(m, // GEMM M dimension
241 n, // GEMM N dimension
242 k, // GEMM K dimension
243 alpha, // scalar alpha
244 A, // matrix A operand
246 B, // matrix B operand
249 C, // source matrix C
251 C, // destination matrix C
// A failed parameter initialization is reported as std::runtime_error.
255 throw std::runtime_error(
256 "Failed to initialize CUTLASS Gemm::Params object.");
259 Gemm::launch(params, stream); // enqueue this run's GEMM on stream
265 } // namespace cloudViewer