1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
10 #include "cloudViewer/core/CUDAUtils.h"
11 #include "core/nns/kernel/BlockSelect.cuh"
12 #include "core/nns/kernel/Limits.cuh"
// Generates the host-side launchers for one block-select instantiation.
//
// Expanding BLOCK_SELECT_IMPL(TYPE, TINDEX, DIR, WARP_Q, THREAD_Q) defines two
// functions:
//   runBlockSelect_<TYPE>_<TINDEX>_<DIR>_<WARP_Q>_      launches blockSelect
//   runBlockSelectPair_<TYPE>_<TINDEX>_<DIR>_<WARP_Q>_  launches blockSelectPair
//
// Both launchers:
//   - use a grid of `num_points` blocks and a compile-time block size:
//     for 4-byte TYPE, 128 threads when WARP_Q <= 1024, else 64;
//     for any other TYPE size, 64 threads when WARP_Q <= 512, else 32;
//   - assert `k <= WARP_Q` and that the runtime `dir` equals the compile-time
//     DIR before launching;
//   - seed the kernel with kInit = Limits<TYPE>::getMin() when `dir` is true,
//     Limits<TYPE>::getMax() otherwise, and with vInit = -1;
//   - enqueue the kernel on `stream` with no dynamic shared memory.
//
// THREAD_Q is forwarded to the kernel template but is not part of the
// generated function names, so a given (TYPE, TINDEX, DIR, WARP_Q) can be
// instantiated with only one THREAD_Q per translation unit.
//
// NOTE(review): presumably dir == true selects the k largest keys (hence the
// minimum as the initial key) -- confirm against blockSelect's definition.
// NOTE(review): the launches are not followed by a cudaGetLastError() check;
// launch-configuration errors would only surface at a later CUDA call.
//
// Only block comments may be used inside the macro body: a `//` comment on a
// continued line would swallow the rest of the macro.
#define BLOCK_SELECT_IMPL(TYPE, TINDEX, DIR, WARP_Q, THREAD_Q) \
    void runBlockSelect_##TYPE##_##TINDEX##_##DIR##_##WARP_Q##_( \
            cudaStream_t stream, TYPE* in, TYPE* outK, TINDEX* outV, bool dir, \
            int k, int dim, int num_points) { \
        /* One thread block per entry counted by num_points. */ \
        auto grid = dim3(num_points); \
        \
        constexpr int kBlockSelectNumThreads = \
                sizeof(TYPE) == 4 ? ((WARP_Q <= 1024) ? 128 : 64) \
                                  : ((WARP_Q <= 512) ? 64 : 32); \
        auto block = dim3(kBlockSelectNumThreads); \
        \
        CLOUDVIEWER_ASSERT(k <= WARP_Q); \
        CLOUDVIEWER_ASSERT(dir == DIR); \
        \
        auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
        /* Initial index handed to the kernel alongside kInit. */ \
        auto vInit = -1; \
        \
        blockSelect<TYPE, TINDEX, DIR, WARP_Q, THREAD_Q, \
                    kBlockSelectNumThreads><<<grid, block, 0, stream>>>( \
                in, outK, outV, kInit, vInit, k, dim, num_points); \
    } \
    \
    void runBlockSelectPair_##TYPE##_##TINDEX##_##DIR##_##WARP_Q##_( \
            cudaStream_t stream, TYPE* inK, TINDEX* inV, TYPE* outK, \
            TINDEX* outV, bool dir, int k, int dim, int num_points) { \
        /* One thread block per entry counted by num_points. */ \
        auto grid = dim3(num_points); \
        \
        constexpr int kBlockSelectNumThreads = \
                sizeof(TYPE) == 4 ? ((WARP_Q <= 1024) ? 128 : 64) \
                                  : ((WARP_Q <= 512) ? 64 : 32); \
        auto block = dim3(kBlockSelectNumThreads); \
        \
        CLOUDVIEWER_ASSERT(k <= WARP_Q); \
        CLOUDVIEWER_ASSERT(dir == DIR); \
        \
        auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax(); \
        /* Initial index handed to the kernel alongside kInit. */ \
        auto vInit = -1; \
        \
        blockSelectPair<TYPE, TINDEX, DIR, WARP_Q, THREAD_Q, \
                        kBlockSelectNumThreads><<<grid, block, 0, stream>>>( \
                inK, inV, outK, outV, kInit, vInit, k, dim, num_points); \
    }
// Expands to a call of runBlockSelect_<TYPE>_<TINDEX>_<DIR>_<WARP_Q>_, the
// launcher whose name is built by BLOCK_SELECT_IMPL from the same arguments.
// The expansion site must have `stream`, `in`, `outK`, `outV`, `dir`, `k`,
// `dim` and `num_points` in scope. No trailing semicolon is emitted, so the
// caller supplies it.
#define BLOCK_SELECT_CALL(TYPE, TINDEX, DIR, WARP_Q) \
    runBlockSelect_##TYPE##_##TINDEX##_##DIR##_##WARP_Q##_(stream, in, outK, \
                                                           outV, dir, k, dim, \
                                                           num_points)
// Expands to a call of runBlockSelectPair_<TYPE>_<TINDEX>_<DIR>_<WARP_Q>_, the
// key/value launcher whose name is built by BLOCK_SELECT_IMPL from the same
// arguments. The expansion site must have `stream`, `inK`, `inV`, `outK`,
// `outV`, `dir`, `k`, `dim` and `num_points` in scope. No trailing semicolon
// is emitted, so the caller supplies it.
#define BLOCK_SELECT_PAIR_CALL(TYPE, TINDEX, DIR, WARP_Q) \
    runBlockSelectPair_##TYPE##_##TINDEX##_##DIR##_##WARP_Q##_( \
            stream, inK, inV, outK, outV, \
            dir, k, dim, num_points)