1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
12 #include <cutlass/gemm/gemm.h>
13 #include <cutlass/gemm/sgemm_traits.h>
15 #include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
16 #include "ml/impl/misc/MemoryAllocation.h"
18 using cloudViewer::utility::DivUp;
20 namespace cloudViewer {
// ---------------------------------------------------------------------------
// CConvTransposeComputeFeaturesCUDA
//
// Computes output features for the transposed continuous convolution on the
// GPU (inferred from the function name and the FillColumnTranspose call).
// The computation uses a column-matrix + GEMM approach:
//   1. FillColumnTranspose writes, for a run of output points, a column
//      matrix with k = spatial_filter_size * in_channels values per output
//      point into the caller-provided temporary buffer.
//   2. A single-precision CUTLASS GEMM computes C = alpha * A * B + beta * C
//      with A = filter (m = out_channels x k), B = columns (k x n) and C =
//      this run's slice of out_features.
//   3. If out_importance is non-null, MultiplyColumns is applied to
//      out_features (presumably scaling each output point's features by its
//      importance -- confirm against MultiplyColumns).
//
// Calling modes, selected by `temp`:
//   * temp == nullptr (size query): MemoryAllocation is dry-run and the
//     minimum / maximum usable temporary sizes are written to temp_size /
//     max_temp_size. NOTE(review): this path is expected to return before
//     any GPU work is enqueued -- confirm.
//   * temp != nullptr: the buffer is used for the column matrix. Output
//     points are processed in runs of as many columns as fit into the
//     buffer. A std::runtime_error is thrown if it cannot hold
//     min(num_out, 32) columns.
//
// filter_dims: the first three entries are multiplied into
// spatial_filter_size; the last two entries are in_channels and
// out_channels.
//
// NOTE(review): A and B are declared `const float*`, C is `float*` and the
// GEMM uses SgemmTraits, so this presumably only works for TFeat == float and
// TOut == float -- confirm the instantiations.
// ---------------------------------------------------------------------------
24 template <class TFeat, class TOut, class TReal, class TIndex>
25 void CConvTransposeComputeFeaturesCUDA(
26 const cudaStream_t& stream,
29 size_t& max_temp_size,
30 int texture_alignment,
32 const std::vector<int>& filter_dims,
35 const TReal* out_positions,
36 const TFeat* out_importance,
38 const TReal* inp_positions,
39 const TFeat* inp_features,
40 const TFeat* inp_neighbors_importance_sum,
41 const int64_t* inp_neighbors_prefix_sum,
42 size_t neighbors_index_size,
43 const TIndex* neighbors_index,
44 const TFeat* neighbors_importance,
45 const int64_t* neighbors_row_splits,
48 InterpolationMode interpolation,
49 CoordinateMapping coordinate_mapping,
51 bool individual_extent,
52 bool isotropic_extent,
// A null temp pointer means the caller only wants the buffer size
// requirements.
54 const bool get_temp_size = !temp;
// Size query: use a misaligned dummy address and an effectively unlimited
// size so the dry-run allocation below accounts for worst-case alignment
// padding without running out of space.
57 temp = (char*)1; // worst case alignment
58 temp_size = std::numeric_limits<int64_t>::max();
// Sub-allocator over the temp buffer; texture_alignment is presumably the
// alignment applied to each allocation -- confirm in MemoryAllocation.
61 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
// The last two filter dimensions are the input and output channel counts;
// the first three are multiplied into the spatial filter size.
63 const int in_channels = filter_dims[filter_dims.size() - 2];
64 const int out_channels = filter_dims[filter_dims.size() - 1];
66 int spatial_filter_size = 1;
67 for (int i = 0; i < 3; ++i) spatial_filter_size *= filter_dims[i];
69 // this defines how much temporary storage we need at least.
70 // we want to allocate memory for at least 32 output points.
// Concretely: min(num_out, 32) columns at minimum, all num_out columns at
// most. Each column holds the spatial_filter_size * in_channels values of
// one output point.
71 const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
72 const size_t max_num_cols_per_run = num_out;
73 const size_t bytes_per_column =
74 sizeof(TFeat) * (spatial_filter_size * in_channels);
75 const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
76 const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
// Size query: record the footprint of the minimum request (temp_size), then
// the footprint of processing all num_out columns at once (max_temp_size).
79 std::pair<char*, size_t> tmp =
80 mem_temp.Alloc<char>(min_temp_size_bytes);
81 temp_size = mem_temp.MaxUsed();
83 mem_temp.Alloc<char>(max_temp_size_bytes);
84 max_temp_size = mem_temp.MaxUsed();
88 // Request segment using all of the temporary memory
89 std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
// Without room for min_temp_size_bytes the column runs cannot proceed.
91 if (mem_columns.second < min_temp_size_bytes) {
93 ss << "temp is too small " << mem_columns.second
94 << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
95 throw std::runtime_error(ss.str());
// Zero all num_out * out_channels output values before the GEMM runs write
// into them. NOTE(review): the returned cudaError_t is not checked.
99 cudaMemsetAsync(out_features, 0, sizeof(TOut) * num_out * out_channels,
// Number of output points (columns) per run: as many as fit into the
// allocated segment, capped at num_out.
102 size_t num_cols_per_run =
103 std::min(mem_columns.second / bytes_per_column, size_t(num_out));
// Single-precision CUTLASS GEMM with column-major A and B.
105 typedef cutlass::gemm::SgemmTraits<
106 cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
107 cutlass::MatrixLayout::kColumnMajor, // layout of B matrix
108 cutlass::Shape<8, 64, 64> // threadblock tile size
112 typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
// The column matrix is stored in the segment obtained from
// AllocLargestSegment and is reused by every run.
114 TFeat* columns = (TFeat*)mem_columns.first;
116 // if we cannot process all data at once we need multiple runs
// NOTE(review): if num_out == 0 then num_cols_per_run is 0 and this DivUp is
// presumably a division by zero -- consider returning early for empty
// output.
117 size_t num_runs = DivUp(num_out, num_cols_per_run);
118 for (size_t run_i = 0; run_i < num_runs; ++run_i) {
// This run covers output points [begin_idx, end_idx); the last run may be
// shorter. NOTE(review): the size_t -> TIndex conversion assumes num_out
// fits in TIndex.
119 const TIndex begin_idx = run_i * num_cols_per_run;
120 const TIndex end_idx =
121 std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
122 const size_t num_cols_this_run = end_idx - begin_idx;
// Build the k x num_cols_this_run column matrix for this run's output
// points.
124 FillColumnTranspose<TFeat, TReal, TIndex>(
125 stream, columns, in_channels, begin_idx, end_idx, num_out,
126 out_positions, num_inp, inp_positions, inp_features,
127 inp_neighbors_importance_sum, inp_neighbors_prefix_sum,
128 neighbors_index_size, neighbors_index, neighbors_importance,
129 neighbors_row_splits, extents, offsets, filter_dims,
130 interpolation, coordinate_mapping, align_corners,
131 individual_extent, isotropic_extent, normalize);
// GEMM for this run: C (m x n) = alpha * A (m x k) * B (k x n) + beta * C.
133 typename Gemm::Params params;
137 int m = out_channels;
138 int k = spatial_filter_size * in_channels;
139 int n = num_cols_this_run;
141 const float* const A = filter;
143 const float* const B = columns;
// Each output point owns out_channels consecutive values, so this run's
// slice starts at begin_idx * out_channels.
146 float* C = out_features + (run_i * num_cols_per_run * out_channels);
150 params.initialize(m, // GEMM M dimension
151 n, // GEMM N dimension
152 k, // GEMM K dimension
153 alpha, // scalar alpha
154 A, // matrix A operand
156 B, // matrix B operand
159 C, // source matrix C
161 C, // destination matrix C (may be different
162 // memory than source C matrix)
166 throw std::runtime_error(
167 "Failed to initialize CUTLASS Gemm::Params object.");
// Enqueue the GEMM on `stream`. This run's FillColumnTranspose also received
// `stream`, which presumably orders the reuse of `columns` across runs.
// NOTE(review): launch errors are not checked here.
170 Gemm::launch(params, stream);
// Optional per-output-point importance weighting (see MultiplyColumns).
173 if (out_importance) {
174 MultiplyColumns(stream, out_channels, num_out, out_features,
181 } // namespace cloudViewer