// ----------------------------------------------------------------------------
// -                   CloudViewer: www.cloudViewer.org                        -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include "core/nns/kernel/DeviceDefs.cuh"
#include "core/nns/kernel/StaticUtils.cuh"

namespace cloudViewer {
namespace core {
namespace nns {
// Merge pairs of sorted lists smaller than blockDim.x (NumThreads)
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool AllThreads,
          bool Dir,
          bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
    static_assert(isPowerOf2(L), "L must be a power-of-2");
    static_assert(isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

    // Which pair of lists we are merging
    int mergeId = threadIdx.x / L;

    // Which thread we are within the merge
    int tid = threadIdx.x % L;

    // listK points to a region of size N * 2 * L
    listK += 2 * L * mergeId;
    listV += 2 * L * mergeId;

    // It's not a bitonic merge, both lists are in the same direction,
    // so handle the first swap assuming the second list is reversed
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;
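
    // Example (L = 4): threads tid = 0..3 compare positions (3,4), (2,5),
    // (1,6) and (0,7); since pos + stride == L + tid, element (L - 1 - tid)
    // of the first list meets element tid of the second, as if the second
    // list were reversed.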

    if (AllThreads || (threadIdx.x < N * L)) {
        K ka = listK[pos];
        K kb = listK[pos + stride];

        // bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        bool swap = Dir ? ka > kb : ka < kb;
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // FIXME: is this a CUDA 9 compiler bug?
        // K& ka = listK[pos];
        // K& kb = listK[pos + stride];

        // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        // swap(s, ka, kb);

        // V& va = listV[pos];
        // V& vb = listV[pos + stride];
        // swap(s, va, vb);
    }

    __syncthreads();

#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
        int pos = 2 * tid - (tid & (stride - 1));
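
        // pos = 2 * tid - (tid & (stride - 1)) spreads the tids so each
        // thread addresses a disjoint (pos, pos + stride) pair: with L = 4,
        // stride = 2 the pairs are (0,2), (1,3), (4,6), (5,7); with
        // stride = 1 they are (0,1), (2,3), (4,5), (6,7).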

        if (AllThreads || (threadIdx.x < N * L)) {
            K ka = listK[pos];
            K kb = listK[pos + stride];

            // bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            bool swap = Dir ? ka > kb : ka < kb;
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;

            // FIXME: is this a CUDA 9 compiler bug?
            // K& ka = listK[pos];
            // K& kb = listK[pos + stride];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            // swap(s, ka, kb);

            // V& va = listV[pos];
            // V& vb = listV[pos + stride];
            // swap(s, va, vb);
        }

        __syncthreads();
    }
}

// Merge pairs of sorted lists larger than blockDim.x (NumThreads)
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
    static_assert(isPowerOf2(L), "L must be a power-of-2");
    static_assert(L >= kWarpSize, "merge list size must be >= 32");
    static_assert(isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

    // For L > NumThreads, each thread has to perform more work
    // per merge step.
    constexpr int kLoopPerThread = L / NumThreads;
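
    // e.g. L = 64 with NumThreads = 32 gives kLoopPerThread = 2: each
    // thread handles two compare/swap pairs in every phase.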

    // It's not a bitonic merge, both lists are in the same direction,
    // so handle the first swap assuming the second list is reversed
#pragma unroll
    for (int loop = 0; loop < kLoopPerThread; ++loop) {
        int tid = loop * NumThreads + threadIdx.x;
        int pos = L - 1 - tid;
        int stride = 2 * tid + 1;

        K ka = listK[pos];
        K kb = listK[pos + stride];

        // bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        bool swap = Dir ? ka > kb : ka < kb;
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // FIXME: is this a CUDA 9 compiler bug?
        // K& ka = listK[pos];
        // K& kb = listK[pos + stride];

        // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        // swap(s, ka, kb);

        // V& va = listV[pos];
        // V& vb = listV[pos + stride];
        // swap(s, va, vb);
    }

    __syncthreads();

    constexpr int kSecondLoopPerThread =
            FullMerge ? kLoopPerThread : kLoopPerThread / 2;
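
    // When FullMerge is false only the first L of the 2 * L outputs are
    // consumed, so the refinement phase can skip half of the per-thread
    // loops.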

#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
        for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
            int tid = loop * NumThreads + threadIdx.x;
            int pos = 2 * tid - (tid & (stride - 1));

            K ka = listK[pos];
            K kb = listK[pos + stride];

            // bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            bool swap = Dir ? ka > kb : ka < kb;
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;

            // FIXME: is this a CUDA 9 compiler bug?
            // K& ka = listK[pos];
            // K& kb = listK[pos + stride];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            // swap(s, ka, kb);

            // V& va = listV[pos];
            // V& vb = listV[pos + stride];
            // swap(s, va, vb);
        }

        __syncthreads();
    }
}

/// Class template to prevent static_assert from firing for
/// mixing smaller/larger than block cases
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool SmallerThanBlock,
          bool FullMerge>
struct BlockMerge {};

/// Merging lists smaller than a block
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, true, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        constexpr int kNumParallelMerges = NumThreads / L;
        constexpr int kNumIterations = N / kNumParallelMerges;

        static_assert(L <= NumThreads, "list must be <= NumThreads");
        static_assert((N < kNumParallelMerges) ||
                              (kNumIterations * kNumParallelMerges == N),
                      "improper selection of N and L");

        if (N < kNumParallelMerges) {
            // We only need L threads per each list to perform the merge
            blockMergeSmall<NumThreads, K, V, N, L, false, Dir, FullMerge>(
                    listK, listV);
        } else {
            // All threads participate
#pragma unroll
            for (int i = 0; i < kNumIterations; ++i) {
                int start = i * kNumParallelMerges * 2 * L;

                blockMergeSmall<NumThreads, K, V, N, L, true, Dir, FullMerge>(
                        listK + start, listV + start);
            }
        }
    }
};

/// Merging lists larger than a block
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, false, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        // Each pair of lists is merged sequentially
#pragma unroll
        for (int i = 0; i < N; ++i) {
            int start = i * 2 * L;

            blockMergeLarge<NumThreads, K, V, L, Dir, FullMerge>(
                    listK + start, listV + start);
        }
    }
};

template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
    constexpr bool kSmallerThanBlock = (L <= NumThreads);

    BlockMerge<NumThreads, K, V, N, L, Dir, kSmallerThanBlock,
               FullMerge>::merge(listK, listV);
}
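
// A minimal usage sketch (an illustrative assumption, not part of this
// header's API): a block of 128 threads merges N = 2 pairs of ascending,
// L = 32 element runs staged in shared memory. Dir = true selects the
// ascending comparison in the swap steps above.
//
//   __global__ void ExampleBlockMergeKernel(float* keys, int32_t* values) {
//       constexpr int kThreads = 128;                // launch blockDim.x
//       constexpr int kTotal = 2 * 2 * 32;           // N * 2 * L elements
//       __shared__ float smemK[kTotal];
//       __shared__ int32_t smemV[kTotal];
//       for (int i = threadIdx.x; i < kTotal; i += kThreads) {
//           smemK[i] = keys[i];
//           smemV[i] = values[i];
//       }
//       __syncthreads();
//       blockMerge<kThreads, float, int32_t, 2, 32, true>(smemK, smemV);
//       // blockMerge ends with __syncthreads(), so results are visible.
//       for (int i = threadIdx.x; i < kTotal; i += kThreads) {
//           keys[i] = smemK[i];
//           values[i] = smemV[i];
//       }
//   }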

}  // namespace nns
}  // namespace core
}  // namespace cloudViewer