ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
ContinuousConvTranspose.cuh
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 #define EIGEN_USE_GPU
10 
#include <Helper.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/sgemm_traits.h>

#include <algorithm>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

#include "ml/impl/continuous_conv/ContinuousConvCUDAKernels.h"
#include "ml/impl/misc/MemoryAllocation.h"
17 
18 using cloudViewer::utility::DivUp;
19 
namespace cloudViewer {
namespace ml {
namespace impl {

/// Computes the output features of a transposed continuous convolution on the
/// GPU.
///
/// Output points are processed in chunks ("runs"). For each run,
/// FillColumnTranspose() writes a column buffer into \p temp. A CUTLASS SGEMM
/// then accumulates filter * columns into the corresponding part of
/// \p out_features. All work is enqueued on \p stream; this function does not
/// synchronize the stream.
///
/// Temporary memory uses a two-phase protocol:
///  - If \p temp is nullptr, only \p temp_size and \p max_temp_size are
///    computed and the function returns without launching any work.
///    \p temp_size is the minimum size, which holds columns for up to 32
///    output points per run. \p max_temp_size is the size needed to process
///    all output points in a single run.
///  - Otherwise \p temp must provide at least the minimum size. Any extra
///    memory (up to \p max_temp_size) is used to process more output points
///    per run, which reduces the number of GEMM launches.
///
/// \tparam TFeat  Filter/feature type. Must be float (an SGEMM is used).
/// \tparam TOut   Output feature type. Must be float (an SGEMM is used).
/// \tparam TReal  Type of positions, extents and offsets.
/// \tparam TIndex Type of point counts and neighbor indices.
///
/// \param stream            CUDA stream on which all work is enqueued.
/// \param temp              Scratch memory, or nullptr to query sizes.
/// \param temp_size         In: size of \p temp in bytes.
///                          Out (query mode): minimum required size.
/// \param max_temp_size     Out (query mode): size needed to process all
///                          output points in one run.
/// \param texture_alignment Alignment passed to MemoryAllocation for
///                          sub-allocations from \p temp.
/// \param out_features      Output written as a column-major
///                          out_channels x num_out matrix
///                          (ldc == out_channels), i.e. out_channels
///                          contiguous values per output point. It is fully
///                          overwritten: first zeroed, then accumulated into.
/// \param filter_dims       Filter shape. The first three entries are
///                          multiplied to get the spatial filter size. The
///                          last two entries are in_channels and out_channels.
/// \param filter            Filter, read as a column-major
///                          out_channels x (spatial_filter_size*in_channels)
///                          matrix (lda == out_channels).
/// \param num_out           Number of output points.
/// \param out_importance    Optional. If non-null, MultiplyColumns() is
///                          applied to \p out_features with it after all
///                          GEMM runs.
///
/// The remaining parameters (out_positions, num_inp, inp_positions,
/// inp_features, inp_neighbors_importance_sum, inp_neighbors_prefix_sum,
/// neighbors_index_size, neighbors_index, neighbors_importance,
/// neighbors_row_splits, extents, offsets, interpolation, coordinate_mapping,
/// align_corners, individual_extent, isotropic_extent, normalize) are
/// forwarded unchanged to FillColumnTranspose().
///
/// \throws std::runtime_error if \p temp is too small, if the CUTLASS
///         parameters cannot be initialized, or if a CUDA call / kernel
///         launch reports an error.
template <class TFeat, class TOut, class TReal, class TIndex>
void CConvTransposeComputeFeaturesCUDA(
        const cudaStream_t& stream,
        void* temp,
        size_t& temp_size,
        size_t& max_temp_size,
        int texture_alignment,
        TOut* out_features,
        const std::vector<int>& filter_dims,
        const TFeat* filter,
        TIndex num_out,
        const TReal* out_positions,
        const TFeat* out_importance,
        TIndex num_inp,
        const TReal* inp_positions,
        const TFeat* inp_features,
        const TFeat* inp_neighbors_importance_sum,
        const int64_t* inp_neighbors_prefix_sum,
        size_t neighbors_index_size,
        const TIndex* neighbors_index,
        const TFeat* neighbors_importance,
        const int64_t* neighbors_row_splits,
        const TReal* extents,
        const TReal* offsets,
        InterpolationMode interpolation,
        CoordinateMapping coordinate_mapping,
        bool align_corners,
        bool individual_extent,
        bool isotropic_extent,
        bool normalize) {
    // The GEMM below is CUTLASS' single-precision SGEMM. It binds `filter`,
    // the column buffer and `out_features` to float pointers, so other types
    // cannot work. Fail with a readable message instead of a pointer
    // conversion error deep inside the body.
    static_assert(std::is_same<TFeat, float>::value,
                  "CConvTransposeComputeFeaturesCUDA requires TFeat == float");
    static_assert(std::is_same<TOut, float>::value,
                  "CConvTransposeComputeFeaturesCUDA requires TOut == float");

    // Converts a failed CUDA status into the exception type used elsewhere
    // in this function.
    auto throw_on_cuda_error = [](cudaError_t err, const char* what) {
        if (err != cudaSuccess) {
            std::stringstream ss;
            ss << what << " failed: " << cudaGetErrorString(err);
            throw std::runtime_error(ss.str());
        }
    };

    const bool get_temp_size = !temp;

    if (get_temp_size) {
        temp = (char*)1;  // worst case alignment
        temp_size = std::numeric_limits<int64_t>::max();
    }

    MemoryAllocation mem_temp(temp, temp_size, texture_alignment);

    const int in_channels = filter_dims[filter_dims.size() - 2];
    const int out_channels = filter_dims[filter_dims.size() - 1];

    int spatial_filter_size = 1;
    for (int i = 0; i < 3; ++i) spatial_filter_size *= filter_dims[i];

    // This defines how much temporary storage we need at least:
    // memory for the columns of at least 32 output points.
    const size_t min_num_cols_per_run = std::min(size_t(num_out), size_t(32));
    const size_t max_num_cols_per_run = num_out;
    const size_t bytes_per_column =
            sizeof(TFeat) * (spatial_filter_size * in_channels);
    const size_t min_temp_size_bytes = min_num_cols_per_run * bytes_per_column;
    const size_t max_temp_size_bytes = max_num_cols_per_run * bytes_per_column;

    if (get_temp_size) {
        std::pair<char*, size_t> tmp =
                mem_temp.Alloc<char>(min_temp_size_bytes);
        temp_size = mem_temp.MaxUsed();
        mem_temp.Free(tmp);
        mem_temp.Alloc<char>(max_temp_size_bytes);
        max_temp_size = mem_temp.MaxUsed();
        return;
    }

    // Request segment using all of the temporary memory
    std::pair<void*, size_t> mem_columns = mem_temp.AllocLargestSegment();

    if (mem_columns.second < min_temp_size_bytes) {
        std::stringstream ss;
        ss << "temp is too small " << mem_columns.second
           << " bytes. Expected at least " << min_temp_size_bytes << " bytes\n";
        throw std::runtime_error(ss.str());
    }

    // The output has no elements, so there is nothing to do.
    // Returning here also avoids two problems below:
    //  - DivUp(num_out, 0), because num_cols_per_run is 0 when num_out == 0;
    //  - a GEMM launch with M == 0 when out_channels == 0.
    if (num_out == 0 || out_channels == 0) return;

    // Init output: the GEMM accumulates into it (beta == 1).
    throw_on_cuda_error(
            cudaMemsetAsync(out_features, 0,
                            sizeof(TOut) * num_out * out_channels, stream),
            "cudaMemsetAsync of out_features");

    // Handle an empty reduction dimension (K == 0, i.e. in_channels or a
    // spatial filter dim is 0). The GEMM adds nothing, so the zeroed output
    // is already the result. This also avoids dividing by bytes_per_column
    // below.
    if (bytes_per_column == 0) return;

    // At least 1 here: mem_columns.second >= min_temp_size_bytes,
    // which is >= bytes_per_column because num_out > 0.
    const size_t num_cols_per_run =
            std::min(mem_columns.second / bytes_per_column, size_t(num_out));

    typedef cutlass::gemm::SgemmTraits<
            cutlass::MatrixLayout::kColumnMajor,  // layout of A matrix
            cutlass::MatrixLayout::kColumnMajor,  // layout of B matrix
            cutlass::Shape<8, 64, 64>             // threadblock tile size
            >
            GemmTraits;

    typedef cutlass::gemm::Gemm<GemmTraits> Gemm;

    TFeat* columns = (TFeat*)mem_columns.first;

    // if we cannot process all data at once we need multiple runs
    const size_t num_runs = DivUp(num_out, num_cols_per_run);
    for (size_t run_i = 0; run_i < num_runs; ++run_i) {
        const TIndex begin_idx = static_cast<TIndex>(run_i * num_cols_per_run);
        const TIndex end_idx = static_cast<TIndex>(
                std::min(size_t(num_out), (run_i + 1) * num_cols_per_run));
        const size_t num_cols_this_run = end_idx - begin_idx;

        FillColumnTranspose<TFeat, TReal, TIndex>(
                stream, columns, in_channels, begin_idx, end_idx, num_out,
                out_positions, num_inp, inp_positions, inp_features,
                inp_neighbors_importance_sum, inp_neighbors_prefix_sum,
                neighbors_index_size, neighbors_index, neighbors_importance,
                neighbors_row_splits, extents, offsets, filter_dims,
                interpolation, coordinate_mapping, align_corners,
                individual_extent, isotropic_extent, normalize);

        typename Gemm::Params params;
        // Column-major GEMM  C(MxN) = alpha * A(MxK) * B(KxN) + beta * C
        //   A: filter,  M = out_channels, K = spatial_filter_size*in_channels
        //   B: columns of this run,       N = num_cols_this_run
        //   C: out_features of this run's output points
        int m = out_channels;
        int k = spatial_filter_size * in_channels;
        int n = static_cast<int>(num_cols_this_run);
        float alpha = 1;
        const float* const A = filter;
        int lda = m;
        const float* const B = columns;
        int ldb = k;
        float beta = 1;  // accumulate into the zero-initialized output
        float* C = out_features + (run_i * num_cols_per_run * out_channels);
        int ldc = m;

        int result =
                params.initialize(m,      // GEMM M dimension
                                  n,      // GEMM N dimension
                                  k,      // GEMM K dimension
                                  alpha,  // scalar alpha
                                  A,      // matrix A operand
                                  lda,
                                  B,  // matrix B operand
                                  ldb,
                                  beta,  // scalar beta
                                  C,     // source matrix C
                                  ldc,
                                  C,  // destination matrix C (may be different
                                      // memory than source C matrix)
                                  ldc);

        if (result) {
            throw std::runtime_error(
                    "Failed to initialize CUTLASS Gemm::Params object.");
        }

        Gemm::launch(params, stream);
        // Kernel launches do not return errors directly. Query the last error
        // so that a failed launch is reported instead of leaving out_features
        // silently incomplete.
        throw_on_cuda_error(cudaGetLastError(), "CUTLASS Gemm launch");
    }

    if (out_importance) {
        MultiplyColumns(stream, out_channels, num_out, out_features,
                        out_importance);
    }
}

}  // namespace impl
}  // namespace ml
}  // namespace cloudViewer