MergeNetwork.cuh
// ----------------------------------------------------------------------------
// -                        CloudViewer: www.cloudViewer.org                  -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include "core/nns/kernel/PtxUtils.cuh"
#include "core/nns/kernel/StaticUtils.cuh"
#include "core/nns/kernel/WarpShuffle.cuh"

namespace cloudViewer {
namespace core {

//
// This file contains functions to:
//
// -perform bitonic merges on pairs of sorted lists, held in
// registers. Each list contains N * kWarpSize (a multiple of 32)
// elements for some N.
// The bitonic merge is implemented for arbitrary sizes:
// sorted list A of size N1 * kWarpSize registers
// sorted list B of size N2 * kWarpSize registers =>
// sorted list C of size (N1 + N2) * kWarpSize registers. N1 and N2
// are >= 1 and don't have to be powers of 2.
//
// -perform bitonic sorts on a set of N * kWarpSize key/value pairs
// held in registers, by using the above bitonic merge as a
// primitive.
// N can be an arbitrary N >= 1; i.e., the bitonic sort here supports
// odd sizes and doesn't require the input to be a power of 2.
//
// The sort or merge network is completely statically instantiated via
// template specialization / expansion and constexpr, and it uses warp
// shuffles to exchange values between warp lanes.
//
// A note about comparisons:
//
// For a sorting network of keys only, we only need one
// comparison (a < b). However, what we really need is that if one
// lane chooses to exchange a value, then the corresponding lane
// also performs the exchange.
// If one just uses the negation !(x < y) in the higher lane, this
// also includes the case where (x == y). One lane then performs an
// exchange and the other doesn't, but because the two values are
// equivalent, nothing has changed.
// So, you can get away with just one comparison and its negation.
//
// If we're sorting keys and values, where equivalent keys can
// exist, then this is a problem, since we want to treat (x, v1)
// as not equivalent to (x, v2).
//
// To remedy this, you can either compare with a lexicographic
// ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which, since
// we're predicating all of the choices, results in 3 comparisons
// being executed, or we can invert the selection so that there is no
// middle choice of equality; the other lane likewise checks
// (b.k > a.k) (the higher lane has the values swapped). Then, the
// first lane swaps if and only if the second lane swaps; if both
// lanes have equivalent keys, no swap is performed. This results in
// only two comparisons being executed.
//
// If we don't consider the values as well, this does not produce a
// consistent ordering among (k, v) pairs with equivalent keys but
// different values; we don't really care about ordering or
// stability here.
//
// I have tried both re-arranging the order in the higher lane to get
// away with one comparison and adding the value to the check; both
// result in greater register consumption or lower speed than just
// performing both the < and > comparisons, so I stick with this.
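//
// As a concrete illustration (not from the original note): if both lanes in
// an exchange hold the same key, one lane evaluates (k > otherK) while the
// other evaluates (k < otherK); both are false, so neither lane exchanges,
// and equal keys never trigger a one-sided swap.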

template <typename T>
inline __device__ void swap(bool swap, T& x, T& y) {
    T tmp = x;
    x = swap ? y : x;
    y = swap ? tmp : y;
}

template <typename T>
inline __device__ void assign(bool assign, T& x, T y) {
    x = assign ? y : x;
}

// This function merges kWarpSize / (2 * L) lists in parallel using warp
// shuffles.
// It works on at most size-16 lists, as we need 32 threads for this
// shuffle merge.
//
// If IsBitonic is false, the first comparison stage is reversed, so the two
// half-lists don't need to be sorted in opposite directions; the network is
// still technically a bitonic merge.
template <typename K, typename V, int L, bool Dir, bool IsBitonic>
inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
    static_assert(isPowerOf2(L), "L must be a power-of-2");
    static_assert(L <= kWarpSize / 2, "merge list size must be <= 16");

    int lane_id = getLaneId();

    if (!IsBitonic) {
        // Reverse the first comparison stage.
        // For example, merging a list of size 8 has the exchanges:
        // 0 <-> 15, 1 <-> 14, ...
        K otherK = shfl_xor(k, 2 * L - 1);
        V otherV = shfl_xor(v, 2 * L - 1);

        // Whether we are the lesser thread in the exchange
        bool is_small = (lane_id & L) == 0;

        if (Dir) {
            // See the comment above on why performing both of these
            // comparisons in the warp seems to win out over the
            // alternatives in practice
            // bool s = is_small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
            bool s = is_small ? (k > otherK) : (k < otherK);
            assign(s, k, otherK);
            assign(s, v, otherV);

        } else {
            bool s = is_small ? (k < otherK) : (k > otherK);
            assign(s, k, otherK);
            assign(s, v, otherV);
        }
    }

#pragma unroll
    for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
        K otherK = shfl_xor(k, stride);
        V otherV = shfl_xor(v, stride);

        // Whether we are the lesser thread in the exchange
        bool is_small = (lane_id & stride) == 0;

        if (Dir) {
            // bool s = is_small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
            bool s = is_small ? (k > otherK) : (k < otherK);
            assign(s, k, otherK);
            assign(s, v, otherV);

        } else {
            // bool s = is_small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
            bool s = is_small ? (k < otherK) : (k > otherK);
            assign(s, k, otherK);
            assign(s, v, otherV);
        }
    }
}

// Template for performing a bitonic merge of an arbitrary set of
// registers
template <typename K, typename V, int N, bool Dir, bool Low, bool Pow2>
struct BitonicMergeStep {};

//
// Power-of-2 merge specialization
//

// All merges eventually call this
template <typename K, typename V, bool Dir, bool Low>
struct BitonicMergeStep<K, V, 1, Dir, Low, true> {
    static inline __device__ void merge(K k[1], V v[1]) {
        // Use warp shuffles
        warpBitonicMergeLE16<K, V, 16, Dir, true>(k[0], v[0]);
    }
};

template <typename K, typename V, int N, bool Dir, bool Low>
struct BitonicMergeStep<K, V, N, Dir, Low, true> {
    static inline __device__ void merge(K k[N], V v[N]) {
        static_assert(isPowerOf2(N), "must be power of 2");
        static_assert(N > 1, "must be N > 1");

#pragma unroll
        for (int i = 0; i < N / 2; ++i) {
            K& ka = k[i];
            V& va = v[i];

            K& kb = k[i + N / 2];
            V& vb = v[i + N / 2];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            bool s = Dir ? ka > kb : ka < kb;
            swap(s, ka, kb);
            swap(s, va, vb);
        }

        {
            K newK[N / 2];
            V newV[N / 2];

#pragma unroll
            for (int i = 0; i < N / 2; ++i) {
                newK[i] = k[i];
                newV[i] = v[i];
            }

            BitonicMergeStep<K, V, N / 2, Dir, true, true>::merge(newK, newV);

#pragma unroll
            for (int i = 0; i < N / 2; ++i) {
                k[i] = newK[i];
                v[i] = newV[i];
            }
        }

        {
            K newK[N / 2];
            V newV[N / 2];

#pragma unroll
            for (int i = 0; i < N / 2; ++i) {
                newK[i] = k[i + N / 2];
                newV[i] = v[i + N / 2];
            }

            BitonicMergeStep<K, V, N / 2, Dir, false, true>::merge(newK, newV);

#pragma unroll
            for (int i = 0; i < N / 2; ++i) {
                k[i + N / 2] = newK[i];
                v[i + N / 2] = newV[i];
            }
        }
    }
};

//
// Non-power-of-2 merge specialization
//

// Low recursion
template <typename K, typename V, int N, bool Dir>
struct BitonicMergeStep<K, V, N, Dir, true, false> {
    static inline __device__ void merge(K k[N], V v[N]) {
        static_assert(!isPowerOf2(N), "must be non-power-of-2");
        static_assert(N >= 3, "must be N >= 3");

        constexpr int kNextHighestPowerOf2 = nextHighestPowerOf2(N);

#pragma unroll
        for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
            K& ka = k[i];
            V& va = v[i];

            K& kb = k[i + kNextHighestPowerOf2 / 2];
            V& vb = v[i + kNextHighestPowerOf2 / 2];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            bool s = Dir ? ka > kb : ka < kb;
            swap(s, ka, kb);
            swap(s, va, vb);
        }

        constexpr int kLowSize = N - kNextHighestPowerOf2 / 2;
        constexpr int kHighSize = kNextHighestPowerOf2 / 2;
        {
            K newK[kLowSize];
            V newV[kLowSize];

#pragma unroll
            for (int i = 0; i < kLowSize; ++i) {
                newK[i] = k[i];
                newV[i] = v[i];
            }

            constexpr bool kLowIsPowerOf2 =
                    isPowerOf2(N - kNextHighestPowerOf2 / 2);
            // FIXME: compiler doesn't like this expression? compiler bug?
            // constexpr bool kLowIsPowerOf2 = isPowerOf2(kLowSize);
            BitonicMergeStep<K, V, kLowSize, Dir,
                             true,  // low
                             kLowIsPowerOf2>::merge(newK, newV);

#pragma unroll
            for (int i = 0; i < kLowSize; ++i) {
                k[i] = newK[i];
                v[i] = newV[i];
            }
        }

        {
            K newK[kHighSize];
            V newV[kHighSize];

#pragma unroll
            for (int i = 0; i < kHighSize; ++i) {
                newK[i] = k[i + kLowSize];
                newV[i] = v[i + kLowSize];
            }

            constexpr bool kHighIsPowerOf2 =
                    isPowerOf2(kNextHighestPowerOf2 / 2);
            // FIXME: compiler doesn't like this expression? compiler bug?
            // constexpr bool kHighIsPowerOf2 =
            //     isPowerOf2(kHighSize);
            BitonicMergeStep<K, V, kHighSize, Dir,
                             false,  // high
                             kHighIsPowerOf2>::merge(newK, newV);

#pragma unroll
            for (int i = 0; i < kHighSize; ++i) {
                k[i + kLowSize] = newK[i];
                v[i + kLowSize] = newV[i];
            }
        }
    }
};

// High recursion
template <typename K, typename V, int N, bool Dir>
struct BitonicMergeStep<K, V, N, Dir, false, false> {
    static inline __device__ void merge(K k[N], V v[N]) {
        static_assert(!isPowerOf2(N), "must be non-power-of-2");
        static_assert(N >= 3, "must be N >= 3");

        constexpr int kNextHighestPowerOf2 = nextHighestPowerOf2(N);

#pragma unroll
        for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
            K& ka = k[i];
            V& va = v[i];

            K& kb = k[i + kNextHighestPowerOf2 / 2];
            V& vb = v[i + kNextHighestPowerOf2 / 2];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            bool s = Dir ? ka > kb : ka < kb;
            swap(s, ka, kb);
            swap(s, va, vb);
        }

        constexpr int kLowSize = kNextHighestPowerOf2 / 2;
        constexpr int kHighSize = N - kNextHighestPowerOf2 / 2;
        {
            K newK[kLowSize];
            V newV[kLowSize];

#pragma unroll
            for (int i = 0; i < kLowSize; ++i) {
                newK[i] = k[i];
                newV[i] = v[i];
            }

            constexpr bool kLowIsPowerOf2 =
                    isPowerOf2(kNextHighestPowerOf2 / 2);
            // FIXME: compiler doesn't like this expression? compiler bug?
            // constexpr bool kLowIsPowerOf2 = isPowerOf2(kLowSize);
            BitonicMergeStep<K, V, kLowSize, Dir,
                             true,  // low
                             kLowIsPowerOf2>::merge(newK, newV);

#pragma unroll
            for (int i = 0; i < kLowSize; ++i) {
                k[i] = newK[i];
                v[i] = newV[i];
            }
        }

        {
            K newK[kHighSize];
            V newV[kHighSize];

#pragma unroll
            for (int i = 0; i < kHighSize; ++i) {
                newK[i] = k[i + kLowSize];
                newV[i] = v[i + kLowSize];
            }

            constexpr bool kHighIsPowerOf2 =
                    isPowerOf2(N - kNextHighestPowerOf2 / 2);
            // FIXME: compiler doesn't like this expression? compiler bug?
            // constexpr bool kHighIsPowerOf2 =
            //     isPowerOf2(kHighSize);
            BitonicMergeStep<K, V, kHighSize, Dir,
                             false,  // high
                             kHighIsPowerOf2>::merge(newK, newV);

#pragma unroll
            for (int i = 0; i < kHighSize; ++i) {
                k[i + kLowSize] = newK[i];
                v[i + kLowSize] = newV[i];
            }
        }
    }
};

/// Merges two sets of registers across the warp of any size;
/// i.e., merges a sorted k/v list of size kWarpSize * N1 with a
/// sorted k/v list of size kWarpSize * N2, where N1 and N2 are any
/// value >= 1
template <typename K,
          typename V,
          int N1,
          int N2,
          bool Dir,
          bool FullMerge = true>
inline __device__ void warpMergeAnyRegisters(K k1[N1],
                                             V v1[N1],
                                             K k2[N2],
                                             V v2[N2]) {
    constexpr int kSmallestN = N1 < N2 ? N1 : N2;

#pragma unroll
    for (int i = 0; i < kSmallestN; ++i) {
        K& ka = k1[N1 - 1 - i];
        V& va = v1[N1 - 1 - i];

        K& kb = k2[i];
        V& vb = v2[i];

        K otherKa;
        V otherVa;

        if (FullMerge) {
            // We need the other values
            otherKa = shfl_xor(ka, kWarpSize - 1);
            otherVa = shfl_xor(va, kWarpSize - 1);
        }

        K otherKb = shfl_xor(kb, kWarpSize - 1);
        V otherVb = shfl_xor(vb, kWarpSize - 1);

        // ka is always first in the list, so we needn't use our lane
        // in this comparison
        // bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb);
        bool swapa = Dir ? ka > otherKb : ka < otherKb;
        assign(swapa, ka, otherKb);
        assign(swapa, va, otherVb);

        // kb is always second in the list, so we needn't use our lane
        // in this comparison
        if (FullMerge) {
            // bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa);
            bool swapb = Dir ? kb < otherKa : kb > otherKa;
            assign(swapb, kb, otherKa);
            assign(swapb, vb, otherVa);

        } else {
            // We don't care about updating elements in the second list
        }
    }

    BitonicMergeStep<K, V, N1, Dir, true, isPowerOf2(N1)>::merge(k1, v1);
    if (FullMerge) {
        // Only if we care about N2 do we need to bother merging it fully
        BitonicMergeStep<K, V, N2, Dir, false, isPowerOf2(N2)>::merge(k2, v2);
    }
}
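
// A minimal usage sketch (illustrative only, not part of this header): merging
// a sorted list of kWarpSize * 1 key/value pairs with a sorted list of
// kWarpSize * 2 pairs, all held in registers across the calling warp. The
// array names below are assumptions made for the example.
//
//   float k1[1];
//   int32_t v1[1];
//   float k2[2];
//   int32_t v2[2];
//   // ... fill k1/v1 and k2/v2 with lists already sorted in direction Dir ...
//   warpMergeAnyRegisters<float, int32_t, 1, 2, /*Dir=*/false>(k1, v1, k2, v2);
//   // Afterwards, the combined 3 * kWarpSize pairs form one sorted sequence
//   // spanning k1/v1 followed by k2/v2 (FullMerge defaults to true).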

// Recursive template that uses the above bitonic merge to perform a
// bitonic sort
template <typename K, typename V, int N, bool Dir>
struct BitonicSortStep {
    static inline __device__ void sort(K k[N], V v[N]) {
        static_assert(N > 1, "did not hit specialized case");

        // Sort recursively
        constexpr int kSizeA = N / 2;
        constexpr int kSizeB = N - kSizeA;

        K aK[kSizeA];
        V aV[kSizeA];

#pragma unroll
        for (int i = 0; i < kSizeA; ++i) {
            aK[i] = k[i];
            aV[i] = v[i];
        }

        BitonicSortStep<K, V, kSizeA, Dir>::sort(aK, aV);

        K bK[kSizeB];
        V bV[kSizeB];

#pragma unroll
        for (int i = 0; i < kSizeB; ++i) {
            bK[i] = k[i + kSizeA];
            bV[i] = v[i + kSizeA];
        }

        BitonicSortStep<K, V, kSizeB, Dir>::sort(bK, bV);

        // Merge halves
        warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir>(aK, aV, bK, bV);

#pragma unroll
        for (int i = 0; i < kSizeA; ++i) {
            k[i] = aK[i];
            v[i] = aV[i];
        }

#pragma unroll
        for (int i = 0; i < kSizeB; ++i) {
            k[i + kSizeA] = bK[i];
            v[i + kSizeA] = bV[i];
        }
    }
};

// Single warp (N == 1) sorting specialization
template <typename K, typename V, bool Dir>
struct BitonicSortStep<K, V, 1, Dir> {
    static inline __device__ void sort(K k[1], V v[1]) {
        // Update this code if kWarpSize changes; the merge widths below
        // should go from 1 -> kWarpSize / 2 in powers of 2
        static_assert(kWarpSize == 32, "unexpected warp size");

        warpBitonicMergeLE16<K, V, 1, Dir, false>(k[0], v[0]);
        warpBitonicMergeLE16<K, V, 2, Dir, false>(k[0], v[0]);
        warpBitonicMergeLE16<K, V, 4, Dir, false>(k[0], v[0]);
        warpBitonicMergeLE16<K, V, 8, Dir, false>(k[0], v[0]);
        warpBitonicMergeLE16<K, V, 16, Dir, false>(k[0], v[0]);
    }
};

/// Sort a list of kWarpSize * N elements in registers, where N is an
/// arbitrary value >= 1
template <typename K, typename V, int N, bool Dir>
inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) {
    BitonicSortStep<K, V, N, Dir>::sort(k, v);
}
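
// A minimal usage sketch (illustrative only, not part of this header): a
// hypothetical kernel, launched with a single warp, in which each lane holds
// one key/value pair in registers and the warp sorts all kWarpSize pairs.
// The kernel name and buffer layout are assumptions made for the example;
// Dir selects the sort direction.
//
//   __global__ void exampleWarpSortKernel(float* keys, int32_t* values) {
//       float k[1] = {keys[getLaneId()]};
//       int32_t v[1] = {values[getLaneId()]};
//
//       // Sort the kWarpSize pairs across the warp entirely in registers.
//       warpSortAnyRegisters<float, int32_t, 1, false>(k, v);
//
//       keys[getLaneId()] = k[0];
//       values[getLaneId()] = v[0];
//   }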

}  // namespace core
}  // namespace cloudViewer