1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
12 #include "core/nns/kernel/PtxUtils.cuh"
13 #include "core/nns/kernel/WarpShuffle.cuh"
namespace cloudViewer {
namespace core {

/// \brief Minimal key/value pair intended for use inside CUDA device code.
///
/// Equality requires both fields to match. The relational operators order
/// pairs lexicographically: keys are compared first, and values are consulted
/// only as a tie-breaker when the keys compare equal.
///
/// \tparam K Key type (needs `==`, `<`, `>` for the operators actually used).
/// \tparam V Value type (needs `==`, `<`, `>` for the operators actually used).
template <typename K, typename V>
struct Pair {
    /// Default constructor. Neither field is explicitly initialized, so for
    /// arithmetic K/V both fields hold indeterminate values until assigned.
    constexpr __device__ inline Pair() {}

    /// Builds a pair holding \p key as the key and \p value as the value.
    constexpr __device__ inline Pair(K key, V value) : k(key), v(value) {}

    /// True only when both the keys and the values compare equal.
    __device__ inline bool operator==(const Pair<K, V>& rhs) const {
        const bool sameKey = (k == rhs.k);
        return sameKey && (v == rhs.v);
    }

    /// Logical negation of operator==.
    __device__ inline bool operator!=(const Pair<K, V>& rhs) const {
        return !(*this == rhs);
    }

    /// Lexicographic less-than: a strictly smaller key wins outright;
    /// otherwise equal keys fall back to comparing the values.
    __device__ inline bool operator<(const Pair<K, V>& rhs) const {
        if (k < rhs.k) {
            return true;
        }
        return k == rhs.k && v < rhs.v;
    }

    /// Lexicographic greater-than: a strictly larger key wins outright;
    /// otherwise equal keys fall back to comparing the values.
    __device__ inline bool operator>(const Pair<K, V>& rhs) const {
        if (k > rhs.k) {
            return true;
        }
        return k == rhs.k && v > rhs.v;
    }

    /// Key: the primary comparison field.
    K k;
    /// Value: compared only when the keys are equal.
    V v;
};

/// Shuffles a Pair "up" by \p delta lanes by shuffling its key and its value
/// independently through the scalar `shfl_up` overload (presumably provided by
/// WarpShuffle.cuh); lane participation/mask semantics are whatever that
/// scalar overload implements.
template <typename T, typename U>
inline __device__ Pair<T, U> shfl_up(const Pair<T, U>& pair,
                                     unsigned int delta,
                                     int width = kWarpSize) {
    const T shuffledKey = shfl_up(pair.k, delta, width);
    const U shuffledValue = shfl_up(pair.v, delta, width);
    return Pair<T, U>(shuffledKey, shuffledValue);
}

/// Butterfly-exchanges a Pair with the lane selected by XOR-ing the caller's
/// lane id with \p laneMask, shuffling key and value separately through the
/// scalar `shfl_xor` overload (presumably provided by WarpShuffle.cuh).
template <typename T, typename U>
inline __device__ Pair<T, U> shfl_xor(const Pair<T, U>& pair,
                                      int laneMask,
                                      int width = kWarpSize) {
    const T shuffledKey = shfl_xor(pair.k, laneMask, width);
    const U shuffledValue = shfl_xor(pair.v, laneMask, width);
    return Pair<T, U>(shuffledKey, shuffledValue);
}

}  // namespace core
}  // namespace cloudViewer