1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
13 #include "core/nns/kernel/BlockMerge.cuh"
14 #include "core/nns/kernel/PtxUtils.cuh"
15 #include "core/nns/kernel/ReductionOps.cuh"
17 namespace cloudViewer {
/// Reduces `val` across lanes through a sequence of XOR exchanges.
///
/// For mask = ReduceWidth / 2, ReduceWidth / 4, ..., 1 each calling thread
/// replaces its value with op(own value, shfl_xor(own value, mask)).
///
/// \tparam T           type of the value being reduced
/// \tparam Op          functor type invoked as op(a, b) on two T values
/// \tparam ReduceWidth starting span of the exchange; defaults to kWarpSize
/// \param val          this thread's contribution
/// \param op           combining functor instance
///
/// NOTE(review): shfl_xor is declared outside this excerpt (presumably a warp
/// XOR-shuffle wrapper from PtxUtils.cuh) - confirm its participation mask
/// and width semantics there.
/// NOTE(review): the halving sequence yields single-bit masks only when
/// ReduceWidth is a power of two - confirm no caller uses another width.
20 template <typename T, typename Op, int ReduceWidth = kWarpSize>
21 __device__ inline T warpReduceAll(T val, Op op) {
// Each pass halves the exchange distance until it reaches 1.
23 for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
24 val = op(val, shfl_xor(val, mask));
// NOTE(review): the loop close and the function's return statement are not
// visible in this excerpt; presumably the accumulated val is returned.
30 /// Sums a register value across all warp threads
///
/// Forwards to warpReduceAll<T, Sum<T>, ReduceWidth> with a
/// default-constructed Sum<T> functor and returns its result.
///
/// \tparam T           type of the value being summed
/// \tparam ReduceWidth forwarded unchanged to warpReduceAll; defaults to
///                     kWarpSize
/// \param val          this thread's contribution
///
/// NOTE(review): Sum<T> is declared outside this excerpt (presumably in
/// ReductionOps.cuh) - confirm it implements addition for T.
31 template <typename T, int ReduceWidth = kWarpSize>
32 __device__ inline T warpReduceAllSum(T val) {
33 return warpReduceAll<T, Sum<T>, ReduceWidth>(val, Sum<T>());
36 /// Performs a block-wide reduction
///
/// Steps visible in this excerpt:
///   1. each thread reduces its value with warpReduceAll;
///   2. a thread whose lane id is below divUp(blockDim.x, kWarpSize) reloads
///      smem[laneId] and the reloaded values are reduced with warpReduceAll
///      again.
/// The stores into smem, the barriers separating the two stages, the
/// alternative value used by lanes at or above the warp count, and the return
/// statement are not visible here.
///
/// \tparam T                 type of the value being reduced
/// \tparam Op                functor type passed through to warpReduceAll
/// \tparam BroadcastAll      use is not visible in this excerpt; presumably
///                           selects whether every thread receives the result
///                           - TODO confirm
/// \tparam KillWARDependency guards a trailing branch whose body is not
///                           visible; presumably a barrier that makes smem
///                           safe to reuse after the call - TODO confirm
/// \param val                this thread's contribution
/// \param op                 combining functor instance
/// \param smem               scratch buffer; the visible read touches indices
///                           0 .. divUp(blockDim.x, kWarpSize) - 1
///
/// NOTE(review): only threadIdx.x and blockDim.x are consulted in the visible
/// lines, so this looks like it assumes a 1-D thread block - confirm.
/// NOTE(review): the second stage indexes smem by lane id, so it presumably
/// requires the number of warps per block to be at most kWarpSize - confirm.
37 template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
38 __device__ inline T blockReduceAll(T val, Op op, T* smem) {
39 int laneId = getLaneId();
40 int warpId = threadIdx.x / kWarpSize;
// Stage 1: reduction of this thread's value via warpReduceAll.
42 val = warpReduceAll<T, Op>(val, op);
// Stage 2: lanes below the warp count pick up the value stored at their lane
// index; the other arm of this conditional is not visible in this excerpt.
49 val = laneId < divUp(blockDim.x, kWarpSize) ? smem[laneId]
51 val = warpReduceAll<T, Op>(val, op);
// Block-scope memory fence; the statements it orders are not visible here.
54 __threadfence_block();
67 if (KillWARDependency) {
74 /// Performs a block-wide reduction of multiple values simultaneously
///
/// Operates in place on val[0 .. Num-1] and returns nothing.
/// Steps visible in this excerpt:
///   1. each val[i] is reduced with warpReduceAll;
///   2. the values are stored to smem at index warpId * Num + i;
///   3. a thread whose lane id is below divUp(blockDim.x, kWarpSize) reloads
///      smem[laneId * Num + i] and each reloaded value is reduced with
///      warpReduceAll again.
/// The beginning of the template header, the conditions guarding steps 2 and
/// 3, the barriers between them, and the bodies of the last two loops are not
/// visible here.
///
/// \param val  array of Num per-thread contributions, overwritten in place
/// \param op   combining functor instance
/// \param smem scratch buffer; the visible accesses use Num consecutive
///             elements per warp, i.e. indices
///             0 .. Num * divUp(blockDim.x, kWarpSize) - 1
///
/// NOTE(review): only threadIdx.x and blockDim.x are consulted in the visible
/// lines, so this looks like it assumes a 1-D thread block - confirm.
/// NOTE(review): KillWARDependency guards a trailing branch whose body is not
/// visible; presumably a barrier protecting later reuse of smem - confirm.
79 bool KillWARDependency>
80 __device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
81 int laneId = getLaneId();
82 int warpId = threadIdx.x / kWarpSize;
// Stage 1: reduce every one of the Num values via warpReduceAll.
85 for (int i = 0; i < Num; ++i) {
86 val[i] = warpReduceAll<T, Op>(val[i], op);
// Store values with a stride of Num elements per warp.
91 for (int i = 0; i < Num; ++i) {
92 smem[warpId * Num + i] = val[i];
// Stage 2: lanes below the warp count reload the slot matching their lane
// index; the other arm of this conditional is not visible in this excerpt.
100 for (int i = 0; i < Num; ++i) {
101 val[i] = laneId < divUp(blockDim.x, kWarpSize)
102 ? smem[laneId * Num + i]
104 val[i] = warpReduceAll<T, Op>(val[i], op);
// Block-scope memory fence; the statements it orders are not visible here.
108 __threadfence_block();
// NOTE(review): the bodies of the next two loops are not visible in this
// excerpt.
112 for (int i = 0; i < Num; ++i) {
122 for (int i = 0; i < Num; ++i) {
127 if (KillWARDependency) {
132 /// Sums a register value across the entire block
///
/// Forwards to the single-value blockReduceAll with Op = Sum<T> and a
/// default-constructed Sum<T> functor, and returns its result.
///
/// \tparam T                 type of the value being summed
/// \tparam BroadcastAll      forwarded unchanged to blockReduceAll
/// \tparam KillWARDependency forwarded unchanged to blockReduceAll
/// \param val                this thread's contribution
/// \param smem               scratch buffer handed through to blockReduceAll;
///                           its size requirement is that of blockReduceAll
133 template <typename T, bool BroadcastAll, bool KillWARDependency>
134 __device__ inline T blockReduceAllSum(T val, T* smem) {
135 return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
136 val, Sum<T>(), smem);
/// Multi-value counterpart of blockReduceAllSum.
///
/// Forwards to the Num-value blockReduceAll with Op = Sum<T> and a
/// default-constructed Sum<T> functor. Nothing is returned; the callee
/// takes the same vals array, so any result is delivered through it.
///
/// \tparam Num               number of values handled per thread
/// \tparam T                 type of the values being summed
/// \tparam BroadcastAll      forwarded unchanged to blockReduceAll
/// \tparam KillWARDependency forwarded unchanged to blockReduceAll
/// \param vals               array of Num per-thread contributions
/// \param smem               scratch buffer handed through to blockReduceAll;
///                           its size requirement is that of blockReduceAll
139 template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
140 __device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
// Returning a void expression is legal C++; this simply forwards the call.
141 return blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
142 vals, Sum<T>(), smem);
146 } // namespace cloudViewer