// ----------------------------------------------------------------------------
// -                      CloudViewer: www.cloudViewer.org                    -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2024 www.cloudViewer.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#include "cloudViewer/core/CUDAUtils.h"
#include "core/Tensor.h"
#include "core/nns/kernel/Limits.cuh"
#include "core/nns/kernel/Pair.cuh"
#include "core/nns/kernel/Reduction.cuh"
#include "core/nns/kernel/ReductionOps.cuh"
#include "core/nns/kernel/Select.cuh"

namespace cloudViewer {
namespace core {
namespace nns {

// L2 + select kernel for k == 1, implements re-use of ||c||^2
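//
// The squared L2 distance factors as
//   ||x - c||^2 = ||x||^2 - 2<x, c> + ||c||^2,
// so the per-centroid ||c||^2 term only needs to be computed once and can be
// re-used across every query row. productDistances is assumed to hold the
// query-dependent part (||x||^2 - 2<x, c>) for each (query, centroid) pair,
// and centroidDistances the per-centroid ||c||^2 term that is added back in
// below.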
template <typename T, typename TIndex, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(T* productDistances,
                             T* centroidDistances,
                             T* outDistances,
                             TIndex* outIndices,
                             int num_points,
                             int dim) {
    // Each block handles kRowsPerBlock rows of the distances (results)
    Pair<T, int> threadMin[kRowsPerBlock];
    __shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];

    T distance[kRowsPerBlock];

#pragma unroll
    for (int i = 0; i < kRowsPerBlock; ++i) {
        threadMin[i].k = Limits<T>::getMax();
        threadMin[i].v = -1;
    }

    // blockIdx.x: which chunk of rows we are responsible for updating
    int rowStart = blockIdx.x * kRowsPerBlock;

    // FIXME: if we have exact multiples, don't need this
    bool endRow = (blockIdx.x == gridDim.x - 1);

    if (endRow) {
        if (num_points % kRowsPerBlock == 0) {
            endRow = false;
        }
    }

    if (endRow) {
        for (int row = rowStart; row < num_points; ++row) {
            for (int col = threadIdx.x; col < dim; col += blockDim.x) {
                distance[0] = centroidDistances[col] +
                              productDistances[row * dim + col];

                if (distance[0] < threadMin[0].k) {
                    threadMin[0].k = distance[0];
                    threadMin[0].v = col;
                }
            }

            // Reduce within the block
            threadMin[0] =
                    blockReduceAll<Pair<T, int>, Min<Pair<T, int>>, false,
                                   false>(threadMin[0], Min<Pair<T, int>>(),
                                          blockMin);

            if (threadIdx.x == 0) {
                outDistances[row] = threadMin[0].k;
                outIndices[row] = threadMin[0].v;
            }

            // so we can use the shared memory again
            __syncthreads();

            threadMin[0].k = Limits<T>::getMax();
            threadMin[0].v = -1;
        }
    } else {
        for (int col = threadIdx.x; col < dim; col += blockDim.x) {
            T centroidDistance = centroidDistances[col];

#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                distance[row] = productDistances[(rowStart + row) * dim + col];
            }

#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                distance[row] = distance[row] + centroidDistance;
            }

#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                if (distance[row] < threadMin[row].k) {
                    threadMin[row].k = distance[row];
                    threadMin[row].v = col;
                }
            }
        }

        // Reduce within the block
        blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int>>, false,
                       false>(threadMin, Min<Pair<T, int>>(), blockMin);

        if (threadIdx.x == 0) {
#pragma unroll
            for (int row = 0; row < kRowsPerBlock; ++row) {
                outDistances[rowStart + row] = threadMin[row].k;
                outIndices[rowStart + row] = threadMin[row].v;
            }
        }
    }
}

// L2 + select kernel for k > 1, no re-use of ||c||^2
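//
// With k > 1 a single running minimum is not enough, so each block scans one
// row and feeds candidate distances into a BlockSelect k-selection structure:
// threads buffer candidates in small per-thread queues, spill them into
// per-warp queues of NumWarpQ entries, and a final reduction merges the warp
// queues in shared memory so the k smallest (distance, index) pairs land in
// smemK/smemV. This description follows the FAISS-style BlockSelect these
// kernels appear to be ported from.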
template <typename T,
          typename TIndex,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void l2SelectMinK(T* productDistances,
                             T* centroidDistances,
                             T* outDistances,
                             TIndex* outIndices,
                             int k,
                             int dim,
                             int num_cols,
                             int tile_cols,
                             T initK) {
    // Each block handles a single row of the distances (results)
    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __shared__ T smemK[kNumWarps * NumWarpQ];
    __shared__ int smemV[kNumWarps * NumWarpQ];

    BlockSelect<T, int, false, NumWarpQ, NumThreadQ, ThreadsPerBlock> heap(
            initK, -1, smemK, smemV, k);

    int row = blockIdx.x;

    // Whole warps must participate in the selection
    // int limit = utils::roundDown(dim, kWarpSize);
    int limit = (dim / kWarpSize) * kWarpSize;
    int i = threadIdx.x;

    for (; i < limit; i += blockDim.x) {
        T v = centroidDistances[i] + productDistances[row * tile_cols + i];
        heap.add(v, i);
    }

    // Handle the remainder if dim is not a multiple of the warp size
    if (i < dim) {
        T v = centroidDistances[i] + productDistances[row * tile_cols + i];
        heap.addThreadQ(v, i);
    }

    heap.reduce();

    for (int i = threadIdx.x; i < k; i += blockDim.x) {
        outDistances[row * k * num_cols + i] = smemK[i];
        outIndices[row * k * num_cols + i] = smemV[i];
    }
}

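// Adds the per-centroid ||c||^2 terms to the (query, centroid) product
// distances and writes the k nearest centroids per query row. num_cols and
// tile_cols are assumed to describe a column-tiled layout: tile_cols is the
// row stride of this productDistances tile, and num_cols scales the output
// row stride (k entries per column tile).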
template <typename T, typename TIndex>
void runL2SelectMin(const cudaStream_t stream,
                    Tensor& productDistances,
                    Tensor& centroidDistances,
                    Tensor& outDistances,
                    Tensor& outIndices,
                    int k,
                    int num_cols,
                    int tile_cols) {
    CLOUDVIEWER_ASSERT(productDistances.GetShape(0) ==
                       outDistances.GetShape(0));
    CLOUDVIEWER_ASSERT(productDistances.GetShape(0) == outIndices.GetShape(0));
    CLOUDVIEWER_ASSERT(centroidDistances.GetShape(0) ==
                       productDistances.GetShape(1));
    CLOUDVIEWER_ASSERT(outDistances.GetShape(1) == k);
    CLOUDVIEWER_ASSERT(outIndices.GetShape(1) == k);
    CLOUDVIEWER_ASSERT(k <= GPU_MAX_SELECTION_K);

    if (k == 1) {
        constexpr int kThreadsPerBlock = 256;
        constexpr int kRowsPerBlock = 8;

        auto block = dim3(kThreadsPerBlock);
        auto grid =
                dim3(utility::DivUp(outDistances.GetShape(0), kRowsPerBlock));

        l2SelectMin1<T, TIndex, kRowsPerBlock, kThreadsPerBlock>
                <<<grid, block, 0, stream>>>(productDistances.GetDataPtr<T>(),
                                             centroidDistances.GetDataPtr<T>(),
                                             outDistances.GetDataPtr<T>(),
                                             outIndices.GetDataPtr<TIndex>(),
                                             (int)productDistances.GetShape(0),
                                             (int)productDistances.GetShape(1));
    } else {
        auto grid = dim3(outDistances.GetShape(0));

#define RUN_L2_SELECT(BLOCK, NUM_WARP_Q, NUM_THREAD_Q)           \
    do {                                                         \
        l2SelectMinK<T, TIndex, NUM_WARP_Q, NUM_THREAD_Q, BLOCK> \
                <<<grid, BLOCK, 0, stream>>>(                    \
                        productDistances.GetDataPtr<T>(),        \
                        centroidDistances.GetDataPtr<T>(),       \
                        outDistances.GetDataPtr<T>(),            \
                        outIndices.GetDataPtr<TIndex>(), k,      \
                        productDistances.GetShape(1), num_cols,  \
                        tile_cols, Limits<T>::getMax());         \
    } while (0)

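        // NumWarpQ must be a power of two no smaller than k, so k is bucketed
        // up to the next supported queue size; the larger NumThreadQ values
        // amortize warp-queue merges for bigger k. The pairings below appear
        // to follow the FAISS selection tuning this dispatcher derives from.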
        // block size 128 for everything <= 1024
        if (k <= 32) {
            RUN_L2_SELECT(128, 32, 2);
        } else if (k <= 64) {
            RUN_L2_SELECT(128, 64, 3);
        } else if (k <= 128) {
            RUN_L2_SELECT(128, 128, 3);
        } else if (k <= 256) {
            RUN_L2_SELECT(128, 256, 4);
        } else if (k <= 512) {
            RUN_L2_SELECT(128, 512, 8);
        } else if (k <= 1024) {
            RUN_L2_SELECT(128, 1024, 8);
#if GPU_MAX_SELECTION_K >= 2048
        } else if (k <= 2048) {
            // smaller block for less shared memory
            RUN_L2_SELECT(64, 2048, 8);
#endif
        } else {
            CLOUDVIEWER_ASSERT(false);
        }

#undef RUN_L2_SELECT
    }

    // CUDA_TEST_ERROR();
}

}  // namespace nns
}  // namespace core
}  // namespace cloudViewer