Select.cuh
// ----------------------------------------------------------------------------
// -                        CloudViewer: www.cloudViewer.org                  -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include "core/nns/kernel/BlockMerge.cuh"
#include "core/nns/kernel/MergeNetwork.cuh"
#include "core/nns/kernel/Pair.cuh"
#include "core/nns/kernel/PtxUtils.cuh"
#include "core/nns/kernel/Reduction.cuh"
#include "core/nns/kernel/StaticUtils.cuh"
namespace cloudViewer {
namespace core {

// Specializations for block-wide monotonic merges that together produce a
// merge sort; we use explicit specializations because what we really want
// is a constexpr loop expansion
template <int NumWarps,
          int NumThreads,
          typename K,
          typename V,
          int NumWarpQ,
          bool Dir>
struct FinalBlockMerge {};

template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // no merge required; single warp
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2), NumWarpQ,
                   !Dir, false>(sharedK, sharedV);
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2), NumWarpQ,
                   !Dir>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4), NumWarpQ * 2,
                   !Dir, false>(sharedK, sharedV);
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2), NumWarpQ,
                   !Dir>(sharedK, sharedV);
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4), NumWarpQ * 2,
                   !Dir>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8), NumWarpQ * 4,
                   !Dir, false>(sharedK, sharedV);
    }
};

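// As a concrete trace of the cascade above (an illustration, not extra code):
// with 8 warps, merge() first pairwise-merges 8 sorted lists of NumWarpQ
// elements into 4 lists of 2 * NumWarpQ, then into 2 lists of 4 * NumWarpQ,
// and the last round combines those two halves. Only the final round passes
// `false` as blockMerge's last template argument, since the trailing part of
// a fully merged result would never be read.
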
// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
    static constexpr int kTotalWarpSortSize = NumWarpQ;

    __device__ inline BlockSelect(
            K initKVal, V initVVal, K* smemK, V* smemV, int k)
        : initK(initKVal),
          initV(initVVal),
          numVals(0),
          warpKTop(initKVal),
          sharedK(smemK),
          sharedV(smemV),
          kMinus1(k - 1) {
        static_assert(isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(isPowerOf2(NumWarpQ), "warp queue must be power-of-2");

        // Fill the per-thread queue keys with the default value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;
        warpK = sharedK + warpId * kTotalWarpSortSize;
        warpV = sharedV + warpId * kTotalWarpSortSize;

        // Fill the warp queue; only the actual queue space needs
        // initialization, not the region where the per-thread queues are
        // written for merging
        for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
            warpK[i] = initK;
            warpV[i] = initV;
        }

        warpFence();
    }

    __device__ inline void addThreadQ(K k, V v) {
        // if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
        if (Dir ? k > warpKTop : k < warpKTop) {
            // Rotate right
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif

        if (!needSort) {
            // no lanes have triggered a sort
            return;
        }

        // This has a trailing warpFence
        mergeWarpQ();

        // Any top-k elements have been merged into the warp queue; we're
        // free to reset the thread queues
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // We have to beat at least this element
        warpKTop = warpK[kMinus1];

        warpFence();
    }

    /// This function handles sorting and merging together the
    /// per-thread queues with the warp-wide queue, creating a sorted
    /// list across both
    __device__ inline void mergeWarpQ() {
        int laneId = getLaneId();

        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir>(threadK, threadV);

        constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
        K warpKRegisters[kNumWarpQRegisters];
        V warpVRegisters[kNumWarpQRegisters];

#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpKRegisters[i] = warpK[i * kWarpSize + laneId];
            warpVRegisters[i] = warpV[i * kWarpSize + laneId];
        }

        warpFence();

        // The warp queue is already sorted, and now that we've sorted the
        // per-thread queue, merge both sorted lists together, producing
        // one sorted list
        warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir,
                              false>(warpKRegisters, warpVRegisters, threadK,
                                     threadV);

        // Write back out the warp queue
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i * kWarpSize + laneId] = warpKRegisters[i];
            warpV[i * kWarpSize + laneId] = warpVRegisters[i];
        }

        warpFence();
    }

    /// WARNING: all threads in a warp must participate in this.
    /// Otherwise, you must call the constituent parts separately.
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have all warps dump and merge their queues; this will produce
        // the final per-warp results
        mergeWarpQ();

        // block-wide dependency; up to this point, all warps have been
        // completely independent
        __syncthreads();

        // All warp queues are contiguous in smem.
        // Now we have kNumWarps sorted lists of NumWarpQ elements each;
        // kNumWarps is a power of 2.
        FinalBlockMerge<kNumWarps, ThreadsPerBlock, K, V, NumWarpQ, Dir>::merge(
                sharedK, sharedV);

        // The block-wide merge has a trailing syncthreads
    }

    // Default element key
    const K initK;

    // Default element value
    const V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th highest (Dir) or lowest (!Dir) element
    K warpKTop;

    // Thread queue values
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // Queues for all warps
    K* sharedK;
    V* sharedV;

    // Our warp's queue (points into sharedK/sharedV)
    // warpK[0] is highest (Dir) or lowest (!Dir)
    K* warpK;
    V* warpV;

    // This is a cached k-1 value
    int kMinus1;
};
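
// Illustrative usage sketch (not part of the original header): a kernel that
// selects the k smallest distances of one array per block, using the
// BlockSelect primary template above. The sentinel `kInit` must be a key that
// can never win the selection (e.g. FLT_MAX when selecting minima); the
// kernel name is hypothetical, and k is assumed to be <= kNumWarpQ.
template <int kThreadsPerBlock, int kNumWarpQ, int kNumThreadQ>
__global__ void blockSelectMinSketch(const float* distances,
                                     int n,
                                     float kInit,
                                     float* outK,
                                     int* outV,
                                     int k) {
    constexpr int kNumWarps = kThreadsPerBlock / kWarpSize;
    __shared__ float smemK[kNumWarps * kNumWarpQ];
    __shared__ int smemV[kNumWarps * kNumWarpQ];

    // Dir = false: keep the smallest keys.
    BlockSelect<float, int, false, kNumWarpQ, kNumThreadQ, kThreadsPerBlock>
            heap(kInit, -1, smemK, smemV, k);

    // Full tiles: every thread in the block calls add(), as the WARNING on
    // add() requires.
    int limit = (n / kThreadsPerBlock) * kThreadsPerBlock;
    int i = threadIdx.x;
    for (; i < limit; i += kThreadsPerBlock) {
        heap.add(distances[i], i);
    }

    // Remainder: only some lanes have data, so call the constituent parts
    // separately; checkThreadQ() is still reached by every thread.
    if (i < n) {
        heap.addThreadQ(distances[i], i);
    }
    heap.checkThreadQ();

    heap.reduce();

    // reduce() leaves the best k (key, value) pairs sorted at the front of
    // the shared-memory queues.
    for (int j = threadIdx.x; j < k; j += kThreadsPerBlock) {
        outK[j] = smemK[j];
        outV[j] = smemV[j];
    }
}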

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K, typename V, bool Dir, int NumThreadQ, int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k)
        : threadK(initK), threadV(initV), sharedK(smemK), sharedV(smemV) {}

    __device__ inline void addThreadQ(K k, V v) {
        // bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        bool swap = Dir ? k > threadK : k < threadK;
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // We don't need to do anything here, since the warp doesn't
        // cooperate until the end
    }

    __device__ inline void add(K k, V v) { addThreadQ(k, v); }

    __device__ inline void reduce() {
        // Reduce within the warp
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        // Each warp writes out a single value
        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;

        if (laneId == 0) {
            sharedK[warpId] = pair.k;
            sharedV[warpId] = pair.v;
        }

        __syncthreads();

        // We typically use this for small blocks (<= 128 threads); having
        // the first thread in the block perform the reduction across warps
        // is faster
        if (threadIdx.x == 0) {
            threadK = sharedK[0];
            threadV = sharedV[0];

#pragma unroll
            for (int i = 1; i < kNumWarps; ++i) {
                K k = sharedK[i];
                V v = sharedV[i];

                // bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k,
                // threadK);
                bool swap = Dir ? k > threadK : k < threadK;
                threadK = swap ? k : threadK;
                threadV = swap ? v : threadV;
            }

            // A thread's smem reads/writes are ordered with respect to
            // itself, so no barrier is needed here
            sharedK[0] = threadK;
            sharedV[0] = threadV;
        }

        // In case other threads wish to read this value
        __syncthreads();
    }

    // threadK is highest (Dir) or lowest (!Dir)
    K threadK;
    V threadV;

    // Where we reduce in smem
    K* sharedK;
    V* sharedV;
};
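
// Illustrative usage sketch (not part of the original header): a block-wide
// argmin via the k == 1 specialization above. Only one shared slot per warp
// is needed, and the winning pair ends up in smemK[0]/smemV[0]; `kInit` and
// the kernel name are placeholders.
template <int kThreadsPerBlock>
__global__ void blockArgMinSketch(
        const float* vals, int n, float kInit, float* outK, int* outV) {
    constexpr int kNumWarps = kThreadsPerBlock / kWarpSize;
    __shared__ float smemK[kNumWarps];
    __shared__ int smemV[kNumWarps];

    BlockSelect<float, int, false, 1, 1, kThreadsPerBlock> heap(
            kInit, -1, smemK, smemV, 1);

    // The k == 1 path keeps a purely per-thread running best, so a ragged
    // tail of inactive threads in the last iteration is harmless here.
    for (int i = threadIdx.x; i < n; i += kThreadsPerBlock) {
        heap.add(vals[i], i);
    }

    // reduce() does use __syncthreads(), so all threads must reach it.
    heap.reduce();

    if (threadIdx.x == 0) {
        *outK = smemK[0];
        *outV = smemV[0];
    }
}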

//
// per-warp WarpSelect
//

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect {
    static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

    __device__ inline WarpSelect(K initKVal, V initVVal, int k)
        : initK(initKVal),
          initV(initVVal),
          numVals(0),
          warpKTop(initKVal),
          kLane((k - 1) % kWarpSize) {
        static_assert(isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(isPowerOf2(NumWarpQ), "warp queue must be power-of-2");

        // Fill the per-thread queue keys with the default value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // Fill the warp queue with the default value
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i] = initK;
            warpV[i] = initV;
        }
    }

    __device__ inline void addThreadQ(K k, V v) {
        // if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
        if (Dir ? k > warpKTop : k < warpKTop) {
            // Rotate right
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif

        if (!needSort) {
            // no lanes have triggered a sort
            return;
        }

        mergeWarpQ();

        // Any top-k elements have been merged into the warp queue; we're
        // free to reset the thread queues
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // We have to beat at least this element
        warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
    }

    /// This function handles sorting and merging together the
    /// per-thread queues with the warp-wide queue, creating a sorted
    /// list across both
    __device__ inline void mergeWarpQ() {
        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir>(threadK, threadV);

        // The warp queue is already sorted, and now that we've sorted the
        // per-thread queue, merge both sorted lists together, producing
        // one sorted list
        warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir,
                              false>(warpK, warpV, threadK, threadV);
    }

    /// WARNING: all threads in a warp must participate in this.
    /// Otherwise, you must call the constituent parts separately.
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have all warps dump and merge their queues; this will produce
        // the final per-warp results
        mergeWarpQ();
    }

    /// Dump the final k selected values for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        int laneId = getLaneId();

#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            int idx = i * kWarpSize + laneId;

            if (idx < k) {
                outK[idx] = warpK[i];
                outV[idx] = warpV[i];
            }
        }
    }

    // Default element key
    const K initK;

    // Default element value
    const V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th highest (Dir) or lowest (!Dir) element
    K warpKTop;

    // Thread queue values
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // warpK[0] is highest (Dir) or lowest (!Dir)
    K warpK[kNumWarpQRegisters];
    V warpV[kNumWarpQRegisters];

    // The lane from which we load an approximation (>= the k-th element)
    // out of the last register in the warp queue (i.e.,
    // warpK[kNumWarpQRegisters - 1])
    int kLane;
};
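
// Illustrative usage sketch (not part of the original header): each warp in
// the block independently selects the k smallest entries of its own row of a
// row-major [numRows, n] distance matrix and emits them with writeOut(). As
// above, `kInit` is a never-winning sentinel, the kernel name is
// hypothetical, and k is assumed to be <= kNumWarpQ.
template <int kThreadsPerBlock, int kNumWarpQ, int kNumThreadQ>
__global__ void warpSelectMinSketch(const float* distances,
                                    int numRows,
                                    int n,
                                    float kInit,
                                    float* outK,
                                    int* outV,
                                    int k) {
    constexpr int kNumWarps = kThreadsPerBlock / kWarpSize;
    int warpId = threadIdx.x / kWarpSize;
    int row = blockIdx.x * kNumWarps + warpId;

    // Warps are fully independent here (WarpSelect never calls
    // __syncthreads()), so out-of-range warps may simply exit.
    if (row >= numRows) {
        return;
    }

    WarpSelect<float, int, false, kNumWarpQ, kNumThreadQ, kThreadsPerBlock>
            heap(kInit, -1, k);

    const float* rowDist = distances + (size_t)row * n;

    // Full tiles: every lane of the warp participates in add().
    int limit = (n / kWarpSize) * kWarpSize;
    int i = getLaneId();
    for (; i < limit; i += kWarpSize) {
        heap.add(rowDist[i], i);
    }

    // Remainder: split add() into its warp-uniform constituent parts.
    if (i < n) {
        heap.addThreadQ(rowDist[i], i);
    }
    heap.checkThreadQ();

    heap.reduce();
    heap.writeOut(outK + (size_t)row * k, outV + (size_t)row * k, k);
}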

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K, typename V, bool Dir, int NumThreadQ, int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline WarpSelect(K initK, V initV, int k)
        : threadK(initK), threadV(initV) {}

    __device__ inline void addThreadQ(K k, V v) {
        // bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        bool swap = Dir ? k > threadK : k < threadK;
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // We don't need to do anything here, since the warp doesn't
        // cooperate until the end
    }

    __device__ inline void add(K k, V v) { addThreadQ(k, v); }

    __device__ inline void reduce() {
        // Reduce within the warp
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        threadK = pair.k;
        threadV = pair.v;
    }

    /// Dump the final k selected values for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        if (getLaneId() == 0) {
            *outK = threadK;
            *outV = threadV;
        }
    }

    // threadK is highest (Dir) or lowest (!Dir)
    K threadK;
    V threadV;
};

}  // namespace core
}  // namespace cloudViewer