1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
10 #include "core/nns/kernel/PtxUtils.cuh"
11 #include "core/nns/kernel/StaticUtils.cuh"
12 #include "core/nns/kernel/WarpShuffle.cuh"
14 namespace cloudViewer {
18 // This file contains functions to:
20 // -perform bitonic merges on pairs of sorted lists, held in
21 // registers. Each list contains N * kWarpSize (multiple of 32)
22 // elements for some N.
23 // The bitonic merge is implemented for arbitrary sizes;
24 // sorted list A of size N1 * kWarpSize registers
25 // sorted list B of size N2 * kWarpSize registers =>
26 // sorted list C of size (N1 + N2) * kWarpSize registers. N1 and N2
27 // are >= 1 and don't have to be powers of 2.
29 // -perform bitonic sorts on a set of N * kWarpSize key/value pairs
30 // held in registers, by using the above bitonic merge as a
32 // N can be an arbitrary N >= 1; i.e., the bitonic sort here supports
33 // odd sizes and doesn't require the input to be a power of 2.
35 // The sort or merge network is completely statically instantiated via
36 // template specialization / expansion and constexpr, and it uses warp
37 // shuffles to exchange values between warp lanes.
39 // A note about comparisons:
41 // For a sorting network of keys only, we only need one
42 // comparison (a < b). However, what we really need to know is
43 // if one lane chooses to exchange a value, then the
44 // corresponding lane should also do the exchange.
45 // Thus, if one just uses the negation !(x < y) in the higher
46 // lane, this will also include the case where (x == y). Thus, one
47 // lane in fact performs an exchange and the other doesn't, but
48 // because the only value being exchanged is equivalent, nothing has
50 // So, you can get away with just one comparison and its negation.
52 // If we're sorting keys and values, where equivalent keys can
53 // exist, then this is a problem, since we want to treat (x, v1)
54 // as not equivalent to (x, v2).
56 // To remedy this, you can either compare with a lexicographic
57 // ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which, since
58 // we're predicating all of the choices, results in 3 comparisons
59 // being executed, or we can invert the selection so that there is no
60 // middle choice of equality; the other lane will likewise
61 // check that (b.k > a.k) (the higher lane has the values
62 // swapped). Then, the first lane swaps if and only if the
63 // second lane swaps; if both lanes have equivalent keys, no
64 // swap will be performed. This results in only two comparisons
67 // If you don't consider values as well, then this does not produce a
68 // consistent ordering among (k, v) pairs with equivalent keys but
69 // different values; for us, we don't really care about ordering or
72 // I have tried both re-arranging the order in the higher lane to get
73 // away with one comparison or adding the value to the check; both
74 // result in greater register consumption or lower speed than just
75 // performing both < and > comparisons with the variables, so I just
79 inline __device__ void swap(bool swap, T& x, T& y) {
86 inline __device__ void assign(bool assign, T& x, T y) {
90 // This function merges kWarpSize / 2L lists in parallel using warp
92 // It works on at most size-16 lists, as we need 32 threads for this
95 // If IsBitonic is false, the first stage is reversed, so we don't
96 // need to sort directionally. It's still technically a bitonic sort.
97 template <typename K, typename V, int L, bool Dir, bool IsBitonic>
98 inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
99 static_assert(isPowerOf2(L), "L must be a power-of-2");
100 static_assert(L <= kWarpSize / 2, "merge list size must be <= 16");
102 int lane_id = getLaneId();
105 // Reverse the first comparison stage.
106 // For example, merging a list of size 8 has the exchanges:
107 // 0 <-> 15, 1 <-> 14, ...
108 K otherK = shfl_xor(k, 2 * L - 1);
109 V otherV = shfl_xor(v, 2 * L - 1);
111 // Whether we are the lesser thread in the exchange
112 bool is_small = (lane_id & L) == 0;
115 // See the comment above how performing both of these
116 // comparisons in the warp seems to win out over the
117 // alternatives in practice
118 // bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
119 bool s = is_small ? (k > otherK) : (k < otherK);
120 assign(s, k, otherK);
121 assign(s, v, otherV);
124 bool s = is_small ? (k < otherK) : (k > otherK);
125 assign(s, k, otherK);
126 assign(s, v, otherV);
131 for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
132 K otherK = shfl_xor(k, stride);
133 V otherV = shfl_xor(v, stride);
135 // Whether we are the lesser thread in the exchange
136 bool is_small = (lane_id & stride) == 0;
139 // bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
140 bool s = is_small ? (k > otherK) : (k < otherK);
141 assign(s, k, otherK);
142 assign(s, v, otherV);
145 // bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
146 bool s = is_small ? (k < otherK) : (k > otherK);
147 assign(s, k, otherK);
148 assign(s, v, otherV);
153 // Template for performing a bitonic merge of an arbitrary set of
155 template <typename K, typename V, int N, bool Dir, bool Low, bool Pow2>
156 struct BitonicMergeStep {};
159 // Power-of-2 merge specialization
162 // All merges eventually call this
163 template <typename K, typename V, bool Dir, bool Low>
164 struct BitonicMergeStep<K, V, 1, Dir, Low, true> {
165 static inline __device__ void merge(K k[1], V v[1]) {
167 warpBitonicMergeLE16<K, V, 16, Dir, true>(k[0], v[0]);
171 template <typename K, typename V, int N, bool Dir, bool Low>
172 struct BitonicMergeStep<K, V, N, Dir, Low, true> {
173 static inline __device__ void merge(K k[N], V v[N]) {
174 static_assert(isPowerOf2(N), "must be power of 2");
175 static_assert(N > 1, "must be N > 1");
178 for (int i = 0; i < N / 2; ++i) {
182 K& kb = k[i + N / 2];
183 V& vb = v[i + N / 2];
185 // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
186 bool s = Dir ? ka > kb : ka < kb;
196 for (int i = 0; i < N / 2; ++i) {
201 BitonicMergeStep<K, V, N / 2, Dir, true, true>::merge(newK, newV);
204 for (int i = 0; i < N / 2; ++i) {
215 for (int i = 0; i < N / 2; ++i) {
216 newK[i] = k[i + N / 2];
217 newV[i] = v[i + N / 2];
220 BitonicMergeStep<K, V, N / 2, Dir, false, true>::merge(newK, newV);
223 for (int i = 0; i < N / 2; ++i) {
224 k[i + N / 2] = newK[i];
225 v[i + N / 2] = newV[i];
232 // Non-power-of-2 merge specialization
236 template <typename K, typename V, int N, bool Dir>
237 struct BitonicMergeStep<K, V, N, Dir, true, false> {
238 static inline __device__ void merge(K k[N], V v[N]) {
239 static_assert(!isPowerOf2(N), "must be non-power-of-2");
240 static_assert(N >= 3, "must be N >= 3");
242 constexpr int kNextHighestPowerOf2 = nextHighestPowerOf2(N);
245 for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
249 K& kb = k[i + kNextHighestPowerOf2 / 2];
250 V& vb = v[i + kNextHighestPowerOf2 / 2];
252 // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
253 bool s = Dir ? ka > kb : ka < kb;
258 constexpr int kLowSize = N - kNextHighestPowerOf2 / 2;
259 constexpr int kHighSize = kNextHighestPowerOf2 / 2;
265 for (int i = 0; i < kLowSize; ++i) {
270 constexpr bool kLowIsPowerOf2 =
271 isPowerOf2(N - kNextHighestPowerOf2 / 2);
272 // FIXME: compiler doesn't like this expression? compiler bug?
273 // constexpr bool kLowIsPowerOf2 = isPowerOf2(kLowSize);
274 BitonicMergeStep<K, V, kLowSize, Dir,
276 kLowIsPowerOf2>::merge(newK, newV);
279 for (int i = 0; i < kLowSize; ++i) {
290 for (int i = 0; i < kHighSize; ++i) {
291 newK[i] = k[i + kLowSize];
292 newV[i] = v[i + kLowSize];
295 constexpr bool kHighIsPowerOf2 =
296 isPowerOf2(kNextHighestPowerOf2 / 2);
297 // FIXME: compiler doesn't like this expression? compiler bug?
298 // constexpr bool kHighIsPowerOf2 =
299 // isPowerOf2(kHighSize);
300 BitonicMergeStep<K, V, kHighSize, Dir,
302 kHighIsPowerOf2>::merge(newK, newV);
305 for (int i = 0; i < kHighSize; ++i) {
306 k[i + kLowSize] = newK[i];
307 v[i + kLowSize] = newV[i];
314 template <typename K, typename V, int N, bool Dir>
315 struct BitonicMergeStep<K, V, N, Dir, false, false> {
316 static inline __device__ void merge(K k[N], V v[N]) {
317 static_assert(!isPowerOf2(N), "must be non-power-of-2");
318 static_assert(N >= 3, "must be N >= 3");
320 constexpr int kNextHighestPowerOf2 = nextHighestPowerOf2(N);
323 for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
327 K& kb = k[i + kNextHighestPowerOf2 / 2];
328 V& vb = v[i + kNextHighestPowerOf2 / 2];
330 // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
331 bool s = Dir ? ka > kb : ka < kb;
336 constexpr int kLowSize = kNextHighestPowerOf2 / 2;
337 constexpr int kHighSize = N - kNextHighestPowerOf2 / 2;
343 for (int i = 0; i < kLowSize; ++i) {
348 constexpr bool kLowIsPowerOf2 =
349 isPowerOf2(kNextHighestPowerOf2 / 2);
350 // FIXME: compiler doesn't like this expression? compiler bug?
351 // constexpr bool kLowIsPowerOf2 = isPowerOf2(kLowSize);
352 BitonicMergeStep<K, V, kLowSize, Dir,
354 kLowIsPowerOf2>::merge(newK, newV);
357 for (int i = 0; i < kLowSize; ++i) {
368 for (int i = 0; i < kHighSize; ++i) {
369 newK[i] = k[i + kLowSize];
370 newV[i] = v[i + kLowSize];
373 constexpr bool kHighIsPowerOf2 =
374 isPowerOf2(N - kNextHighestPowerOf2 / 2);
375 // FIXME: compiler doesn't like this expression? compiler bug?
376 // constexpr bool kHighIsPowerOf2 =
377 // isPowerOf2(kHighSize);
378 BitonicMergeStep<K, V, kHighSize, Dir,
380 kHighIsPowerOf2>::merge(newK, newV);
383 for (int i = 0; i < kHighSize; ++i) {
384 k[i + kLowSize] = newK[i];
385 v[i + kLowSize] = newV[i];
391 /// Merges two sets of registers across the warp of any size;
392 /// i.e., merges a sorted k/v list of size kWarpSize * N1 with a
393 /// sorted k/v list of size kWarpSize * N2, where N1 and N2 are any
395 template <typename K,
400 bool FullMerge = true>
401 inline __device__ void warpMergeAnyRegisters(K k1[N1],
405 constexpr int kSmallestN = N1 < N2 ? N1 : N2;
408 for (int i = 0; i < kSmallestN; ++i) {
409 K& ka = k1[N1 - 1 - i];
410 V& va = v1[N1 - 1 - i];
419 // We need the other values
420 otherKa = shfl_xor(ka, kWarpSize - 1);
421 otherVa = shfl_xor(va, kWarpSize - 1);
424 K otherKb = shfl_xor(kb, kWarpSize - 1);
425 V otherVb = shfl_xor(vb, kWarpSize - 1);
427 // ka is always first in the list, so we needn't use our lane
428 // in this comparison
429 // bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb);
430 bool swapa = Dir ? ka > otherKb : ka < otherKb;
431 assign(swapa, ka, otherKb);
432 assign(swapa, va, otherVb);
434 // kb is always second in the list, so we needn't use our lane
435 // in this comparison
437 // bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa);
438 bool swapb = Dir ? kb < otherKa : kb > otherKa;
439 assign(swapb, kb, otherKa);
440 assign(swapb, vb, otherVa);
443 // We don't care about updating elements in the second list
447 BitonicMergeStep<K, V, N1, Dir, true, isPowerOf2(N1)>::merge(k1, v1);
449 // Only if we care about N2 do we need to bother merging it fully
450 BitonicMergeStep<K, V, N2, Dir, false, isPowerOf2(N2)>::merge(k2, v2);
454 // Recursive template that uses the above bitonic merge to perform a
456 template <typename K, typename V, int N, bool Dir>
457 struct BitonicSortStep {
458 static inline __device__ void sort(K k[N], V v[N]) {
459 static_assert(N > 1, "did not hit specialized case");
462 constexpr int kSizeA = N / 2;
463 constexpr int kSizeB = N - kSizeA;
469 for (int i = 0; i < kSizeA; ++i) {
474 BitonicSortStep<K, V, kSizeA, Dir>::sort(aK, aV);
480 for (int i = 0; i < kSizeB; ++i) {
481 bK[i] = k[i + kSizeA];
482 bV[i] = v[i + kSizeA];
485 BitonicSortStep<K, V, kSizeB, Dir>::sort(bK, bV);
488 warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir>(aK, aV, bK, bV);
491 for (int i = 0; i < kSizeA; ++i) {
497 for (int i = 0; i < kSizeB; ++i) {
498 k[i + kSizeA] = bK[i];
499 v[i + kSizeA] = bV[i];
504 // Single warp (N == 1) sorting specialization
505 template <typename K, typename V, bool Dir>
506 struct BitonicSortStep<K, V, 1, Dir> {
507 static inline __device__ void sort(K k[1], V v[1]) {
508 // Update this code if this changes
509 // should go from 1 -> kWarpSize in multiples of 2
510 static_assert(kWarpSize == 32, "unexpected warp size");
512 warpBitonicMergeLE16<K, V, 1, Dir, false>(k[0], v[0]);
513 warpBitonicMergeLE16<K, V, 2, Dir, false>(k[0], v[0]);
514 warpBitonicMergeLE16<K, V, 4, Dir, false>(k[0], v[0]);
515 warpBitonicMergeLE16<K, V, 8, Dir, false>(k[0], v[0]);
516 warpBitonicMergeLE16<K, V, 16, Dir, false>(k[0], v[0]);
520 /// Sort a list of kWarpSize * N elements in registers, where N is an
522 template <typename K, typename V, int N, bool Dir>
523 inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) {
524 BitonicSortStep<K, V, N, Dir>::sort(k, v);
528 } // namespace cloudViewer