// ----------------------------------------------------------------------------
// -                        CloudViewer: www.cloudViewer.org                  -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include "core/nns/kernel/BlockMerge.cuh"
#include "core/nns/kernel/MergeNetwork.cuh"
#include "core/nns/kernel/Pair.cuh"
#include "core/nns/kernel/PtxUtils.cuh"
#include "core/nns/kernel/Reduction.cuh"
#include "core/nns/kernel/StaticUtils.cuh"

namespace cloudViewer {
namespace core {
namespace nns {

// Specialization for block-wide monotonic merges producing a merge sort,
// since what we really want is a constexpr loop expansion
template <int NumWarps,
          int NumThreads,
          typename K,
          typename V,
          int NumWarpQ,
          bool Dir>
struct FinalBlockMerge {};
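
// Each specialization below performs log2(NumWarps) merge passes, doubling
// the sorted list length each pass. For example, with 8 warps the 8 per-warp
// lists of NumWarpQ elements become 4 lists of NumWarpQ * 2, then 2 lists of
// NumWarpQ * 4; the final pass only needs the best NumWarpQ elements of the
// combined list, so it does not fully merge the second list.
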
template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // no merge required; single warp
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2), NumWarpQ,
                   !Dir, false>(sharedK, sharedV);
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2), NumWarpQ,
                   !Dir>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4), NumWarpQ * 2,
                   !Dir, false>(sharedK, sharedV);
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ, bool Dir>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2), NumWarpQ,
                   !Dir>(sharedK, sharedV);
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4), NumWarpQ * 2,
                   !Dir>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8), NumWarpQ * 4,
                   !Dir, false>(sharedK, sharedV);
    }
};

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
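//
// Typical usage, as a minimal sketch only (the kernel name, `maxKeySentinel`,
// and the launch shape are illustrative assumptions, not part of this
// header): one block selects the k smallest keys from one row of n values.
//
//   template <typename K, int kThreads, int kWarpQ, int kThreadQ>
//   __global__ void selectMinK(
//           const K* in, int n, K* outK, int* outV, int k) {
//       __shared__ K smemK[(kThreads / kWarpSize) * kWarpQ];
//       __shared__ int smemV[(kThreads / kWarpSize) * kWarpQ];
//
//       // Dir == false -> keep the smallest keys; maxKeySentinel is any key
//       // that compares worse than all real inputs
//       BlockSelect<K, int, false, kWarpQ, kThreadQ, kThreads> heap(
//               maxKeySentinel, -1, smemK, smemV, k);
//
//       // For simplicity this sketch assumes n is a multiple of kThreads,
//       // since add() must be called by every lane of a warp; real callers
//       // round down and handle the tail separately.
//       const K* row = in + blockIdx.x * n;
//       for (int i = threadIdx.x; i < n; i += kThreads) {
//           heap.add(row[i], i);
//       }
//       heap.reduce();
//
//       for (int i = threadIdx.x; i < k; i += kThreads) {
//           outK[blockIdx.x * k + i] = smemK[i];
//           outV[blockIdx.x * k + i] = smemV[i];
//       }
//   }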
template <typename K,
          typename V,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
    static constexpr int kTotalWarpSortSize = NumWarpQ;

    __device__ inline BlockSelect(
            K initKVal, V initVVal, K* smemK, V* smemV, int k)
        : initK(initKVal),
          initV(initVVal),
          numVals(0),
          warpKTop(initKVal),
          sharedK(smemK),
          sharedV(smemV),
          kMinus1(k - 1) {
        static_assert(isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(isPowerOf2(NumWarpQ), "warp queue must be power-of-2");

        // Fill the per-thread queue keys with the default value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;
        warpK = sharedK + warpId * kTotalWarpSortSize;
        warpV = sharedV + warpId * kTotalWarpSortSize;

        // Fill warp queue (only the actual queue space is fine, not where
        // we write the per-thread queues for merging)
        for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
            warpK[i] = initK;
            warpV[i] = initV;
        }

        warpFence();
    }

    __device__ inline void addThreadQ(K k, V v) {
        // if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
        if (Dir ? k > warpKTop : k < warpKTop) {
            // Rotate right
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif
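
        // The warp vote makes the branch below warp-uniform: either every
        // lane participates in the merge, or none does.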
        if (!needSort) {
            // no lanes have triggered a sort
            return;
        }

        // This has a trailing warpFence
        mergeWarpQ();

        // Any top-k elements have been merged into the warp queue; we're
        // free to reset the thread queues
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // We have to beat at least this element
        warpKTop = warpK[kMinus1];

        warpFence();
    }

    /// This function handles sorting and merging together the
    /// per-thread queues with the warp-wide queue, creating a sorted
    /// list across both
    __device__ inline void mergeWarpQ() {
        int laneId = getLaneId();

        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir>(threadK, threadV);

        constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
        K warpKRegisters[kNumWarpQRegisters];
        V warpVRegisters[kNumWarpQRegisters];
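
        // The warp queue is striped across the warp: register i of lane j
        // holds shared element i * kWarpSize + j, so each lane loads
        // kNumWarpQRegisters strided elements below.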
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpKRegisters[i] = warpK[i * kWarpSize + laneId];
            warpVRegisters[i] = warpV[i * kWarpSize + laneId];
        }

        warpFence();

        // The warp queue is already sorted, and now that we've sorted the
        // per-thread queue, merge both sorted lists together, producing
        // one sorted list
        warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir,
                              false>(warpKRegisters, warpVRegisters, threadK,
                                     threadV);

        // Write back out the warp queue
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i * kWarpSize + laneId] = warpKRegisters[i];
            warpV[i * kWarpSize + laneId] = warpVRegisters[i];
        }

        warpFence();
    }

    /// WARNING: all threads in a warp must participate in this.
    /// Otherwise, you must call the constituent parts separately.
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have all warps dump and merge their queues; this will produce
        // the final per-warp results
        mergeWarpQ();

        // block-wide dep; thus far, all warps have been completely
        // independent
        __syncthreads();

        // All warp queues are contiguous in smem.
        // Now, we have kNumWarps lists of NumWarpQ elements.
        // This is a power of 2.
        FinalBlockMerge<kNumWarps, ThreadsPerBlock, K, V, NumWarpQ,
                        Dir>::merge(sharedK, sharedV);

        // The block-wide merge has a trailing syncthreads
    }
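
    // After reduce(), sharedK[0..k) / sharedV[0..k) hold the block-wide
    // selection, sorted by key.
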
    // Default element key
    K initK;

    // Default element value
    V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th highest (Dir) or lowest (!Dir) element
    K warpKTop;

    // Thread queue values
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // Queues for all warps
    K* sharedK;
    V* sharedV;

    // Our warp's queue (points into sharedK/sharedV)
    // warpK[0] is highest (Dir) or lowest (!Dir)
    K* warpK;
    V* warpV;

    // This is a cached k-1 value
    int kMinus1;
};

/// Specialization for k == 1 (NumWarpQ == 1)
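/// Here selection degenerates to a block-wide arg-min/arg-max: each thread
/// tracks its single best pair, each warp reduces with warpReduceAll, and
/// thread 0 combines the per-warp winners in shared memory.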
template <typename K, typename V, bool Dir, int NumThreadQ, int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k)
        : threadK(initK), threadV(initV), sharedK(smemK), sharedV(smemV) {}

    __device__ inline void addThreadQ(K k, V v) {
        // bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        bool swap = Dir ? k > threadK : k < threadK;
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // We don't need to do anything here, since the warp doesn't
        // cooperate until the end
    }

    __device__ inline void add(K k, V v) { addThreadQ(k, v); }

    __device__ inline void reduce() {
        // Reduce within the warp
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        // Each warp writes out a single value
        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;

        if (laneId == 0) {
            sharedK[warpId] = pair.k;
            sharedV[warpId] = pair.v;
        }

        __syncthreads();

        // We typically use this for small blocks (<= 128); just having the
        // first thread in the block perform the reduction across warps is
        // faster
        if (threadIdx.x == 0) {
            threadK = sharedK[0];
            threadV = sharedV[0];

#pragma unroll
            for (int i = 1; i < kNumWarps; ++i) {
                K k = sharedK[i];
                V v = sharedV[i];

                // bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k,
                // threadK);
                bool swap = Dir ? k > threadK : k < threadK;
                threadK = swap ? k : threadK;
                threadV = swap ? v : threadV;
            }

            // Hopefully a thread's smem reads/writes are ordered wrt
            // itself, so no barrier needed :)
            sharedK[0] = threadK;
            sharedV[0] = threadV;
        }

        // In case other threads wish to read this value
        __syncthreads();
    }

    // threadK is highest (Dir) or lowest (!Dir)
    K threadK;
    V threadV;

    // Where we reduce in smem
    K* sharedK;
    V* sharedV;
};

//
// per-warp WarpSelect
//

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
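//
// Minimal usage sketch (the kernel name and shapes are illustrative
// assumptions, not part of this header). Unlike BlockSelect, each warp runs
// a fully independent selection and needs no shared memory; every warp
// handles its own row and writes its own output:
//
//   template <typename K, int kThreads, int kWarpQ, int kThreadQ>
//   __global__ void warpSelectMinK(
//           const K* in, int n, K* outK, int* outV, int k) {
//       // maxKeySentinel is any key that compares worse than all inputs
//       WarpSelect<K, int, false, kWarpQ, kThreadQ, kThreads> heap(
//               maxKeySentinel, -1, k);
//
//       int warpId = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpSize;
//       const K* row = in + warpId * n;
//
//       // assumes n is a multiple of kWarpSize; all lanes must call add()
//       for (int i = getLaneId(); i < n; i += kWarpSize) {
//           heap.add(row[i], i);
//       }
//       heap.reduce();
//       heap.writeOut(outK + warpId * k, outV + warpId * k, k);
//   }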
template <typename K,
          typename V,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect {
    static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

    __device__ inline WarpSelect(K initKVal, V initVVal, int k)
        : initK(initKVal),
          initV(initVVal),
          numVals(0),
          warpKTop(initKVal),
          kLane((k - 1) % kWarpSize) {
        static_assert(isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(isPowerOf2(NumWarpQ), "warp queue must be power-of-2");

        // Fill the per-thread queue keys with the default value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // Fill the warp queue with the default value
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i] = initK;
            warpV[i] = initV;
        }
    }

    __device__ inline void addThreadQ(K k, V v) {
        // if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
        if (Dir ? k > warpKTop : k < warpKTop) {
            // Rotate right
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif

        if (!needSort) {
            // no lanes have triggered a sort
            return;
        }

        mergeWarpQ();

        // Any top-k elements have been merged into the warp queue; we're
        // free to reset the thread queues
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // We have to beat at least this element
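        // warpK[kNumWarpQRegisters - 1] at lane kLane ranks at or beyond the
        // k-th best element, so it is a safe threshold; broadcast it to
        // every lane with a shuffle.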
        warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
    }

    /// This function handles sorting and merging together the
    /// per-thread queues with the warp-wide queue, creating a sorted
    /// list across both
    __device__ inline void mergeWarpQ() {
        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir>(threadK, threadV);

        // The warp queue is already sorted, and now that we've sorted the
        // per-thread queue, merge both sorted lists together, producing
        // one sorted list
        warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir,
                              false>(warpK, warpV, threadK, threadV);
    }

    /// WARNING: all threads in a warp must participate in this.
    /// Otherwise, you must call the constituent parts separately.
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have all warps dump and merge their queues; this will produce
        // the final per-warp results
        mergeWarpQ();
    }

    /// Dump final k selected values for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        int laneId = getLaneId();

#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            int idx = i * kWarpSize + laneId;

            if (idx < k) {
                outK[idx] = warpK[i];
                outV[idx] = warpV[i];
            }
        }
    }

    // Default element key
    K initK;

    // Default element value
    V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th highest (Dir) or lowest (!Dir) element
    K warpKTop;

    // Thread queue values
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // warpK[0] is highest (Dir) or lowest (!Dir)
    K warpK[kNumWarpQRegisters];
    V warpV[kNumWarpQRegisters];

    // The lane from which we load an approximation (>= k-th) of the k-th
    // element, taken from the last register in the warp queue (i.e.,
    // warpK[kNumWarpQRegisters - 1]).
    int kLane;
};

/// Specialization for k == 1 (NumWarpQ == 1)
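/// Selection degenerates to a warp-wide arg-min/arg-max: each lane tracks
/// its single best pair, a warp reduction combines them, and lane 0 writes
/// the result out.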
template <typename K, typename V, bool Dir, int NumThreadQ, int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline WarpSelect(K initK, V initV, int k)
        : threadK(initK), threadV(initV) {}

    __device__ inline void addThreadQ(K k, V v) {
        // bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        bool swap = Dir ? k > threadK : k < threadK;
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // We don't need to do anything here, since the warp doesn't
        // cooperate until the end
    }

    __device__ inline void add(K k, V v) { addThreadQ(k, v); }

    __device__ inline void reduce() {
        // Reduce within the warp
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        threadK = pair.k;
        threadV = pair.v;
    }

    /// Dump final k selected values for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        if (getLaneId() == 0) {
            outK[0] = threadK;
            outV[0] = threadV;
        }
    }

    // threadK is highest (Dir) or lowest (!Dir)
    K threadK;
    V threadV;
};

}  // namespace nns
}  // namespace core
}  // namespace cloudViewer