1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
10 #include "core/Tensor.h"
11 #include "core/nns/kernel/Select.cuh"
13 namespace cloudViewer {
// Kernel that feeds the keys of one input row into a BlockSelect helper and
// writes the first k entries of the helper's shared-memory buffers to
// outK / outV. The value paired with each key is its position i in the row.
// NOTE(review): the template header, the remaining parameters, the
// declarations of `row` and `i`, and any reduction/synchronization between
// the selection and the final copy are not visible in this excerpt; the
// comments below describe only the statements that are shown.
22 __global__ void blockSelect(K* in,
// Number of warps per block, derived from the compile-time block size.
30 constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory buffers for keys and values: NumWarpQ slots per warp.
32 __shared__ K smemK[kNumWarps * NumWarpQ];
33 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Selector over the shared buffers, seeded with initK / initV.
// Presumably it keeps the k best keys in the order given by Dir; confirm
// against Select.cuh.
35 BlockSelect<K, IndexType, Dir, NumWarpQ, NumThreadQ, ThreadsPerBlock> heap(
36 initK, initV, smemK, smemV, k);
38 // Grid is exactly sized to rows available
// Address of element i in row `row`, with `in` addressed as rows of `dim`
// elements.
42 K* inStart = in + dim * row + i;
44 // Whole warps must participate in the selection
// dim rounded down to a multiple of kWarpSize.
45 int limit = (dim / kWarpSize) * kWarpSize;
// Main pass: i and the read pointer both advance by ThreadsPerBlock.
47 for (; i < limit; i += ThreadsPerBlock) {
48 heap.add(*inStart, (IndexType)i);
49 inStart += ThreadsPerBlock;
52 // Handle last remainder fraction of a warp of elements
// The guard on this call is not visible here; presumably `i < dim`.
54 heap.addThreadQ(*inStart, (IndexType)i);
// Copy the first k shared-memory entries to the outputs, strided across the
// threads of the block.
// NOTE(review): the destination offset uses `row * dim`, whereas
// blockSelectPair in this file writes at `row * k`. If outK / outV hold k
// entries per row, this stride writes to the wrong location (and out of
// bounds when dim > k); confirm the intended output layout.
59 for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
60 *(outK + row * dim + i) = smemK[i];
61 *(outV + row * dim + i) = smemV[i];
// Kernel variant of the block selection that reads (key, value) pairs from
// two parallel inputs, inK and inV, instead of using the element position as
// the value. The first k entries of the shared-memory buffers are written to
// outK / outV with k entries per row.
// NOTE(review): the template header, the remaining parameters, the
// declarations of `row` and `i`, and any reduction/synchronization between
// the selection and the final copy are not visible in this excerpt; the
// comments below describe only the statements that are shown.
71 __global__ void blockSelectPair(K* inK,
// Number of warps per block, derived from the compile-time block size.
80 constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory buffers for keys and values: NumWarpQ slots per warp.
82 __shared__ K smemK[kNumWarps * NumWarpQ];
83 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Selector over the shared buffers, seeded with initK / initV.
// Presumably it keeps the k best keys in the order given by Dir; confirm
// against Select.cuh.
85 BlockSelect<K, IndexType, Dir, NumWarpQ, NumThreadQ, ThreadsPerBlock> heap(
86 initK, initV, smemK, smemV, k);
88 // Grid is exactly sized to rows available
// Addresses of element i in row `row` of the key and value inputs, both
// addressed as rows of `dim` elements.
92 K* inKStart = &inK[row * dim + i];
93 IndexType* inVStart = &inV[row * dim + i];
95 // Whole warps must participate in the selection
// dim rounded down to a multiple of kWarpSize.
96 int limit = (dim / kWarpSize) * kWarpSize;
// Main pass: i and both read pointers advance by ThreadsPerBlock.
98 for (; i < limit; i += ThreadsPerBlock) {
99 heap.add(*inKStart, *inVStart);
100 inKStart += ThreadsPerBlock;
101 inVStart += ThreadsPerBlock;
104 // Handle last remainder fraction of a warp of elements
// The guard on this call is not visible here; presumably `i < dim`.
106 heap.addThreadQ(*inKStart, *inVStart);
// Copy the first k shared-memory entries to row `row` of the outputs (row
// stride k), strided across the threads of the block.
111 for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
112 outK[row * k + i] = smemK[i];
113 outV[row * k + i] = smemV[i];
117 void runBlockSelectPair(cudaStream_t stream,
127 void runBlockSelectPair(cudaStream_t stream,
137 void runBlockSelectPair(cudaStream_t stream,
147 void runBlockSelectPair(cudaStream_t stream,
158 } // namespace cloudViewer