ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
BlockMerge.cuh
Go to the documentation of this file.
// ----------------------------------------------------------------------------
// -                        CloudViewer: www.cloudViewer.org                  -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include "core/nns/kernel/DeviceDefs.cuh"
#include "core/nns/kernel/StaticUtils.cuh"

namespace cloudViewer {
namespace core {

// Merge pairs of sorted lists smaller than blockDim.x (NumThreads).
//
// Memory layout: listK/listV hold N pairs of length-L lists stored
// contiguously (2 * L elements per pair, N * 2 * L elements total).
// Keys in listK and their companion values in listV are permuted together.
//
// Template parameters:
//   NumThreads - blockDim.x; must be a power of 2.
//   K, V       - key / value element types; K needs operator< and operator>.
//   N          - number of list pairs to merge.
//   L          - length of each input list; power of 2, L <= NumThreads.
//   AllThreads - true when every thread has work (the N * L compare slots
//                cover the block), letting the bounds test be compiled out.
//   Dir        - comparison direction: false merges ascending (uses <),
//                true merges descending (uses >).
//   FullMerge  - unused in this function; kept so the parameter list
//                mirrors blockMergeLarge, where it halves the second phase.
//
// Contains __syncthreads(): the entire thread block must call this
// function, including threads that fail the AllThreads bounds test.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool AllThreads,
          bool Dir,
          bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
    static_assert(isPowerOf2(L), "L must be a power-of-2");
    static_assert(isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

    // Which pair of lists we are merging
    int mergeId = threadIdx.x / L;

    // Which thread we are within the merge
    int tid = threadIdx.x % L;

    // listK points to a region of size N * 2 * L
    listK += 2 * L * mergeId;
    listV += 2 * L * mergeId;

    // It's not a bitonic merge, both lists are in the same direction,
    // so handle the first swap assuming the second list is reversed
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;

    if (AllThreads || (threadIdx.x < N * L)) {
        // Compare-exchange element pos of the first list against its
        // mirror element in the (conceptually reversed) second list.
        K ka = listK[pos];
        K kb = listK[pos + stride];

        bool swap = Dir ? ka > kb : ka < kb;
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        // Values move with their keys.
        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // NOTE: explicit loads/stores are used instead of reference-based
        // swaps; the original author suspected a CUDA 9 compiler bug with
        // the reference form (see the FIXME in the upstream source).
    }

    __syncthreads();

    // Remaining passes of the merging network: compare-exchange at
    // halving strides L/2, L/4, ..., 1. The loop-local stride/pos shadow
    // the first-phase variables above.
#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
        int pos = 2 * tid - (tid & (stride - 1));

        if (AllThreads || (threadIdx.x < N * L)) {
            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? ka > kb : ka < kb;
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;

            // Same explicit load/store form as above (suspected CUDA 9
            // compiler bug with the reference-based swap).
        }

        __syncthreads();
    }
}
106 
// Merge pairs of sorted lists larger than blockDim.x (NumThreads).
//
// listK/listV hold one pair of length-L lists stored contiguously
// (2 * L elements). Since L >= NumThreads, each thread performs
// L / NumThreads compare-exchanges per pass.
//
// Template parameters:
//   NumThreads - blockDim.x; must be a power of 2.
//   K, V       - key / value element types; K needs operator< and operator>.
//   L          - length of each input list; power of 2, L >= NumThreads
//                and L >= kWarpSize.
//   Dir        - false merges ascending (uses <); true descending (uses >).
//   FullMerge  - if false, only half of the merged output is required, so
//                the second phase performs half the per-thread iterations.
//
// Contains __syncthreads(): the entire thread block must call this.
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
    static_assert(isPowerOf2(L), "L must be a power-of-2");
    // Message kept in sync with the actual kWarpSize-based condition
    // (it previously hardcoded ">= 32").
    static_assert(L >= kWarpSize, "merge list size must be >= kWarpSize");
    static_assert(isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

    // For L > NumThreads, each thread has to perform more work
    // per each stride.
    constexpr int kLoopPerThread = L / NumThreads;

    // It's not a bitonic merge, both lists are in the same direction,
    // so handle the first swap assuming the second list is reversed
#pragma unroll
    for (int loop = 0; loop < kLoopPerThread; ++loop) {
        int tid = loop * NumThreads + threadIdx.x;
        int pos = L - 1 - tid;
        int stride = 2 * tid + 1;

        K ka = listK[pos];
        K kb = listK[pos + stride];

        bool swap = Dir ? ka > kb : ka < kb;
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        // Values move with their keys.
        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // NOTE: explicit loads/stores instead of reference-based swaps;
        // the original author suspected a CUDA 9 compiler bug with the
        // reference form.
    }

    __syncthreads();

    // When only half of the output matters (!FullMerge), half the
    // per-thread iterations suffice in the halving passes below.
    constexpr int kSecondLoopPerThread =
            FullMerge ? kLoopPerThread : kLoopPerThread / 2;

#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
        for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
            int tid = loop * NumThreads + threadIdx.x;
            int pos = 2 * tid - (tid & (stride - 1));

            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? ka > kb : ka < kb;
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;
        }

        __syncthreads();
    }
}
197 
/// Primary template for the block-merge dispatcher. Only the two partial
/// specializations below (SmallerThanBlock == true / false) are ever
/// instantiated; keeping the primary template empty prevents the
/// static_asserts of the non-selected case (smaller- vs larger-than-block)
/// from firing.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool SmallerThanBlock,
          bool FullMerge>
struct BlockMerge {};
209 
/// Merging lists smaller than a block: each group of L threads handles one
/// pair of lists, so several pairs can be merged in parallel per pass.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, true, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        // How many pairs the block can merge concurrently, and how many
        // sweeps that implies to cover all N pairs.
        constexpr int kMergesPerPass = NumThreads / L;
        constexpr int kPasses = N / kMergesPerPass;

        static_assert(L <= NumThreads, "list must be <= NumThreads");
        static_assert(
                (N < kMergesPerPass) || (kPasses * kMergesPerPass == N),
                "improper selection of N and L");

        if (N < kMergesPerPass) {
            // Fewer pairs than the block can host: only L threads per
            // list participate; the rest just pass through the barriers.
            blockMergeSmall<NumThreads, K, V, N, L, false, Dir, FullMerge>(
                    listK, listV);
            return;
        }

        // Enough pairs to occupy every thread: sweep the data in chunks
        // of kMergesPerPass pairs (2 * L elements per pair).
#pragma unroll
        for (int pass = 0; pass < kPasses; ++pass) {
            int offset = pass * kMergesPerPass * 2 * L;
            blockMergeSmall<NumThreads, K, V, N, L, true, Dir, FullMerge>(
                    listK + offset, listV + offset);
        }
    }
};
244 
/// Merging lists larger than a block: each pair exceeds the block width,
/// so the N pairs are merged one after another by the whole block.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, false, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        // Sequentially merge each 2 * L element pair of lists.
#pragma unroll
        for (int pair = 0; pair < N; ++pair) {
            K* pairK = listK + pair * 2 * L;
            V* pairV = listV + pair * 2 * L;

            blockMergeLarge<NumThreads, K, V, L, Dir, FullMerge>(pairK,
                                                                 pairV);
        }
    }
};
265 
// Entry point: merge N pairs of sorted length-L key/value lists held in
// listK/listV, choosing at compile time between the smaller-than-block
// and larger-than-block implementations.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
    BlockMerge<NumThreads, K, V, N, L, Dir, (L <= NumThreads),
               FullMerge>::merge(listK, listV);
}
279 
}  // namespace core
}  // namespace cloudViewer