1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
9 #include <thrust/device_vector.h>
10 #include <thrust/iterator/counting_iterator.h>
11 #include <thrust/sequence.h>
12 #include <thrust/sort.h>
14 #include "ml/Helper.h"
15 #include "ml/contrib/IoUImpl.h"
16 #include "ml/contrib/Nms.h"
18 namespace cloudViewer {
// Sorts `values` in place on the device and fills `sort_indices` with the
// permutation achieving that order: after the call, sorted values[i] was
// originally located at position sort_indices[i].
//
// \param values Device pointer to n values of type T; sorted in place.
// \param sort_indices Device pointer to n int64_t; receives the permutation.
// \param n Number of elements.
// \param descending If true, sort from largest to smallest; otherwise
//        ascending (default).
template <typename T>
static void SortIndices(T *values,
                        int64_t *sort_indices,
                        int64_t n,
                        bool descending = false) {
    // Cast to thrust device pointer.
    thrust::device_ptr<T> values_dptr = thrust::device_pointer_cast(values);
    thrust::device_ptr<int64_t> sort_indices_dptr =
            thrust::device_pointer_cast(sort_indices);

    // Fill sort_indices with 0, 1, ..., n-1.
    thrust::sequence(sort_indices_dptr, sort_indices_dptr + n, 0);

    // Sort values and sort_indices together. stable_sort preserves the
    // relative order of equal keys, so ties keep their original index order.
    if (descending) {
        thrust::stable_sort_by_key(values_dptr, values_dptr + n,
                                   sort_indices_dptr, thrust::greater<T>());
    } else {
        thrust::stable_sort_by_key(values_dptr, values_dptr + n,
                                   sort_indices_dptr);
    }
}
// Computes the pairwise-overlap bitmap for NMS.
//
// Launch layout: 2D grid of (num_block_cols, num_block_rows) blocks with
// NMS_BLOCK_SIZE threads per block (1D). Block (bx, by) compares the "row"
// boxes [NMS_BLOCK_SIZE*by, NMS_BLOCK_SIZE*by + row_size) against the "col"
// boxes [NMS_BLOCK_SIZE*bx, NMS_BLOCK_SIZE*bx + col_size), all in
// score-sorted order via sort_indices.
//
// \param boxes Device pointer, (n, 5) floats per box; layout consumed by
//        IoUBev2DWithMinAndMax.
// \param sort_indices Device pointer, n indices sorting boxes by score.
// \param mask Device pointer, (n, num_block_cols) uint64_t output.
//        mask[i, j] bit k (counted from the right) is 1 iff sorted box i
//        overlaps sorted box NMS_BLOCK_SIZE*j+k beyond nms_overlap_thresh.
// \param n Number of boxes.
// \param nms_overlap_thresh IoU threshold above which boxes overlap.
// \param num_block_cols Column-wise number of blocks, DivUp(n, NMS_BLOCK_SIZE).
__global__ void NmsKernel(const float *boxes,
                          const int64_t *sort_indices,
                          uint64_t *mask,
                          const int n,
                          const double nms_overlap_thresh,
                          const int num_block_cols) {
    // Row-wise block index.
    const int block_row_idx = blockIdx.y;
    // Column-wise block index.
    const int block_col_idx = blockIdx.x;

    // Local block row size.
    const int row_size =
            fminf(n - block_row_idx * NMS_BLOCK_SIZE, NMS_BLOCK_SIZE);
    // Local block col size.
    const int col_size =
            fminf(n - block_col_idx * NMS_BLOCK_SIZE, NMS_BLOCK_SIZE);

    // Fill local block_boxes by fetching the global box memory.
    // block_boxes = boxes[NBS*block_col_idx : NBS*block_col_idx+col_size, :].
    // TODO: It is also possible to load the comparison target to the shared
    // memory as well.
    __shared__ float block_boxes[NMS_BLOCK_SIZE * 5];
    if (threadIdx.x < col_size) {
        float *dst = block_boxes + threadIdx.x * 5;
        const int src_idx = NMS_BLOCK_SIZE * block_col_idx + threadIdx.x;
        const float *src = boxes + sort_indices[src_idx] * 5;
        dst[0] = src[0];
        dst[1] = src[1];
        dst[2] = src[2];
        dst[3] = src[3];
        dst[4] = src[4];
    }
    // All threads must see the shared-memory tile before anyone reads it.
    __syncthreads();

    // Comparing src and dst. In one block, the following src and dst indices
    // are compared:
    // - src: BS * block_row_idx : BS * block_row_idx + row_size
    // - dst: BS * block_col_idx : BS * block_col_idx + col_size
    //
    // With all blocks, all src and dst indices are compared.
    //
    // mask[i, j] is a 64-bit integer where mask[i, j][k] (k counted from right)
    // is 1 iff box[i] overlaps with box[BS*j+k].
    if (threadIdx.x < row_size) {
        // src_idx indices the global memory.
        const int src_idx = NMS_BLOCK_SIZE * block_row_idx + threadIdx.x;
        // dst_idx indices the shared memory. On the diagonal block, start
        // one past the current box so a box is never compared with itself
        // and each pair is only tested once.
        int dst_idx = block_row_idx == block_col_idx ? threadIdx.x + 1 : 0;

        uint64_t t = 0;
        while (dst_idx < col_size) {
            if (IoUBev2DWithMinAndMax(boxes + sort_indices[src_idx] * 5,
                                      block_boxes + dst_idx * 5) >
                nms_overlap_thresh) {
                t |= 1ULL << dst_idx;
            }
            dst_idx++;
        }
        mask[src_idx * num_block_cols + block_col_idx] = t;
    }
}
// Runs non-maximum suppression on device boxes and returns the indices of
// the boxes to keep, ordered by descending score.
//
// \param boxes Device pointer, (n, 5) floats per box.
// \param scores Device pointer, n scores; not modified (a copy is sorted).
// \param n Number of boxes.
// \param nms_overlap_thresh IoU threshold; a box overlapping a kept,
//        higher-scoring box by more than this is suppressed.
// \return Indices (into the caller's box array) of the kept boxes.
std::vector<int64_t> NmsCUDAKernel(const float *boxes,
                                   const float *scores,
                                   int n,
                                   double nms_overlap_thresh) {
    // A zero-dimension grid launch is an invalid configuration; nothing to do.
    if (n <= 0) {
        return {};
    }

    // Column-wise number of blocks.
    const int num_block_cols = cloudViewer::utility::DivUp(n, NMS_BLOCK_SIZE);

    // Compute sort indices. Sort a copy of the scores so the caller's
    // array is left untouched.
    float *scores_copy = nullptr;
    CLOUDVIEWER_CUDA_CHECK(
            cudaMalloc((void **)&scores_copy, n * sizeof(float)));
    CLOUDVIEWER_CUDA_CHECK(cudaMemcpy(scores_copy, scores, n * sizeof(float),
                                      cudaMemcpyDeviceToDevice));
    int64_t *sort_indices = nullptr;
    CLOUDVIEWER_CUDA_CHECK(
            cudaMalloc((void **)&sort_indices, n * sizeof(int64_t)));
    SortIndices(scores_copy, sort_indices, n, true);
    CLOUDVIEWER_CUDA_CHECK(cudaFree(scores_copy));

    // Allocate masks on device.
    uint64_t *mask_ptr = nullptr;
    CLOUDVIEWER_CUDA_CHECK(cudaMalloc((void **)&mask_ptr,
                                      n * num_block_cols * sizeof(uint64_t)));

    // Launch the kernel on a 2D grid: each block compares one row tile of
    // boxes against one column tile (NMS_BLOCK_SIZE boxes each).
    dim3 blocks(cloudViewer::utility::DivUp(n, NMS_BLOCK_SIZE),
                cloudViewer::utility::DivUp(n, NMS_BLOCK_SIZE));
    dim3 threads(NMS_BLOCK_SIZE);
    NmsKernel<<<blocks, threads>>>(boxes, sort_indices, mask_ptr, n,
                                   nms_overlap_thresh, num_block_cols);
    // Surface launch-configuration errors; execution errors surface at the
    // blocking memcpy below.
    CLOUDVIEWER_CUDA_CHECK(cudaGetLastError());

    // Copy cuda masks to cpu.
    std::vector<uint64_t> mask_vec(n * num_block_cols);
    uint64_t *mask = mask_vec.data();
    CLOUDVIEWER_CUDA_CHECK(cudaMemcpy(mask_vec.data(), mask_ptr,
                                      n * num_block_cols * sizeof(uint64_t),
                                      cudaMemcpyDeviceToHost));
    CLOUDVIEWER_CUDA_CHECK(cudaFree(mask_ptr));

    // Copy sort_indices to cpu.
    std::vector<int64_t> sort_indices_cpu(n);
    CLOUDVIEWER_CUDA_CHECK(cudaMemcpy(sort_indices_cpu.data(), sort_indices,
                                      n * sizeof(int64_t),
                                      cudaMemcpyDeviceToHost));

    // Write to keep_indices in CPU.
    // remv_cpu has n bits in total. If the bit is 1, the corresponding
    // box will be removed.
    // TODO: This part can be implemented in CUDA. We use the original author's
    // implementation here.
    std::vector<uint64_t> remv_cpu(num_block_cols, 0);
    std::vector<int64_t> keep_indices;
    for (int i = 0; i < n; i++) {
        int block_col_idx = i / NMS_BLOCK_SIZE;
        int inner_block_col_idx = i % NMS_BLOCK_SIZE;  // threadIdx.x

        // Querying the i-th bit in remv_cpu, counted from the right.
        // - remv_cpu[block_col_idx]: the block bitmap containing the query
        // - 1ULL << inner_block_col_idx: the one-hot bitmap to extract i
        if (!(remv_cpu[block_col_idx] & (1ULL << inner_block_col_idx))) {
            // Keep the i-th box.
            keep_indices.push_back(sort_indices_cpu[i]);

            // Any box that overlaps with the i-th box will be removed.
            // Earlier columns are already decided, so start at this block.
            uint64_t *p = mask + i * num_block_cols;
            for (int j = block_col_idx; j < num_block_cols; j++) {
                remv_cpu[j] |= p[j];
            }
        }
    }

    CLOUDVIEWER_CUDA_CHECK(cudaFree(sort_indices));
    return keep_indices;
}
188 } // namespace contrib
190 } // namespace cloudViewer