1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
12 #include <cutlass/gemm/gemm.h>
13 #include <cutlass/gemm/sgemm_traits.h>
15 #include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
16 #include "ml/impl/misc/MemoryAllocation.h"
17 #include "ml/impl/sparse_conv/SparseConvCUDAKernels.h"
19 using cloudViewer::utility::DivUp;
21 namespace cloudViewer {
// Computes the output features of a transposed sparse convolution, enqueuing
// all device work on `stream`.
//
// filter_dims: the last two entries are in_channels and out_channels; the
// product of the remaining leading entries is the number of kernel elements.
//
// Size query (temp == nullptr): temp_size receives the temp bytes that
// MemoryAllocation reports as needed to process min(num_out, 32) output points
// per run, and max_temp_size the bytes needed to process all num_out points in
// a single run. (This branch presumably returns before the compute path --
// confirm.)
//
// Compute (temp != nullptr): the largest segment of `temp` becomes the column
// buffer; std::runtime_error is thrown if it is smaller than the size-query
// minimum. out_features is zeroed, then output points are processed in chunks:
// FillColumnTranspose builds the column matrix for the chunk and a CUTLASS
// SGEMM writes filter x columns into that chunk's slice of out_features.
// If out_importance is non-null, MultiplyColumns is applied to out_features
// afterwards (presumably scaling each output point by its importance --
// confirm).
//
// Work is only enqueued (cudaMemsetAsync, FillColumnTranspose, Gemm::launch,
// MultiplyColumns); this function does not appear to synchronize `stream`, so
// callers must synchronize before reading out_features.
//
// NOTE(review): the GEMM operands are `float` pointers (A = filter,
// B = columns, C = out_features), so this path assumes TFeat and TOut are
// float -- confirm no other instantiation is used.
25 template <class TFeat, class TOut, class TIndex, class TKernelIndex>
26 void SparseConvTransposeComputeFeaturesCUDA(
27 const cudaStream_t& stream,
30 size_t& max_temp_size,
31 int texture_alignment,
33 const std::vector<int>& filter_dims,
36 const TFeat* out_importance,
38 const TFeat* inp_features,
39 const TFeat* inp_neighbors_importance_sum,
40 const int64_t* inp_neighbors_prefix_sum,
41 size_t neighbors_index_size,
42 const TIndex* neighbors_index,
43 const TKernelIndex* neighbors_kernel_index,
44 const TFeat* neighbors_importance,
45 const int64_t* neighbors_row_splits,
// A null temp pointer selects the size-query mode instead of computing.
47 const bool get_temp_size = !temp;
// Size query: use a fake base pointer with the worst possible alignment and an
// unbounded size so the allocator below can measure the worst-case padding.
50 temp = (char*)1; // worst case alignment
51 temp_size = std::numeric_limits<int64_t>::max();
54 MemoryAllocation mem_temp(temp, temp_size, texture_alignment);
// filter_dims = [kernel dims..., in_channels, out_channels]
56 const int in_channels = filter_dims[filter_dims.size() - 2];
57 const int out_channels = filter_dims[filter_dims.size() - 1];
59 int num_kernel_elements = 1;
60 for (std::size_t i = 0; i < filter_dims.size() - 2; ++i)
61 num_kernel_elements *= filter_dims[i];
63 // this defines how much temporary storage we need at least.
64 // we want to allocate memory for at least 32 output points.
// One column of the column matrix corresponds to one output point and holds
// num_kernel_elements * in_channels values of type TFeat.
65 const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
66 const size_t max_num_cols_per_run = num_out;
67 const size_t bytes_per_column =
68 sizeof(TFeat) * (num_kernel_elements * in_channels);
69 const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
70 const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;
// Size query: report the high-water mark after allocating the minimum buffer
// (temp_size), then after additionally allocating the buffer for a single run
// over all outputs (max_temp_size).
73 std::pair<char*, size_t> tmp =
74 mem_temp.Alloc<char>(min_temp_size_bytes);
75 temp_size = mem_temp.MaxUsed();
77 mem_temp.Alloc<char>(max_temp_size_bytes);
78 max_temp_size = mem_temp.MaxUsed();
82 // Request segment using all of the temporary memory
83 std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();
// The caller's buffer cannot hold even min_num_cols_per_run columns.
85 if (mem_columns.second < min_temp_size_bytes) {
87 ss << "temp is too small " << mem_columns.second
88 << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
89 throw std::runtime_error(ss.str());
// Zero the whole output; each run below writes only its own slice.
// NOTE(review): the cudaError_t returned by cudaMemsetAsync is not checked.
93 cudaMemsetAsync(out_features, 0, sizeof(TOut) * num_out * out_channels,
// Process as many output points per run as fit in the buffer, capped at
// num_out.
// NOTE(review): if num_out == 0 this is 0 and DivUp(num_out, 0) below divides
// by zero; if bytes_per_column == 0 (in_channels or a kernel dim is 0) the
// division on the next line does. Confirm an early return covers these cases.
96 size_t num_cols_per_run =
97 std::min(mem_columns.second / bytes_per_column, size_t(num_out));
// Single-precision CUTLASS GEMM; both A (filter) and B (columns) are
// column-major.
99 typedef cutlass::gemm::SgemmTraits<
100 cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
101 cutlass::MatrixLayout::kColumnMajor, // layout of B matrix
102 cutlass::Shape<8, 64, 64> // threadblock tile size
106 typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
108 TFeat* columns = (TFeat*)mem_columns.first;
110 // if we cannot process all data at once we need multiple runs
// The last run may be partial: end_idx is clamped to num_out.
111 size_t num_runs = DivUp(num_out, num_cols_per_run);
112 for (size_t run_i = 0; run_i < num_runs; ++run_i) {
// NOTE(review): the size_t run bounds are narrowed to TIndex; this assumes
// num_out fits in TIndex.
113 const TIndex begin_idx = run_i * num_cols_per_run;
114 const TIndex end_idx =
115 std::min(size_t(num_out), (run_i + 1) * num_cols_per_run);
116 const size_t num_cols_this_run = end_idx - begin_idx;
// Build the column matrix for output points begin_idx .. end_idx-1 (presumably
// a half-open range, given num_cols_this_run = end_idx - begin_idx).
118 FillColumnTranspose<TFeat, TIndex, TKernelIndex>(
119 stream, columns, in_channels, begin_idx, end_idx, num_out,
120 num_inp, inp_features, inp_neighbors_importance_sum,
121 inp_neighbors_prefix_sum, neighbors_index_size, neighbors_index,
122 neighbors_kernel_index, neighbors_importance,
123 neighbors_row_splits, num_kernel_elements, normalize);
125 typename Gemm::Params params;
// C (m x n) = A (m x k) * B (k x n): one column of C per output point of this
// run.
129 int m = out_channels;
130 int k = num_kernel_elements * in_channels;
131 int n = num_cols_this_run;
133 const float* const A = filter;
135 const float* const B = columns;
// Each output point owns out_channels consecutive values, so this run's slice
// starts at its first output point.
138 float* C = out_features + (run_i * num_cols_per_run * out_channels);
142 params.initialize(m, // GEMM M dimension
143 n, // GEMM N dimension
144 k, // GEMM K dimension
145 alpha, // scalar alpha
146 A, // matrix A operand
148 B, // matrix B operand
// C is both source and destination; out_features was zeroed above and each run
// touches a disjoint slice.
151 C, // source matrix C
153 C, // destination matrix C (may be different
154 // memory than source C matrix)
158 throw std::runtime_error(
159 "Failed to initialize CUTLASS Gemm::Params object.");
// NOTE(review): launch errors are not checked; consider cudaGetLastError()
// after the launch.
162 Gemm::launch(params, stream);
// Optional post-processing of the full output with the per-output importance.
165 if (out_importance) {
166 MultiplyColumns(stream, out_channels, num_out, out_features,
173 } // namespace cloudViewer