ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
Nms.cu
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include <Helper.h>
9 #include <thrust/device_vector.h>
10 #include <thrust/iterator/counting_iterator.h>
11 #include <thrust/sequence.h>
12 #include <thrust/sort.h>
13 
14 #include "ml/Helper.h"
15 #include "ml/contrib/IoUImpl.h"
16 #include "ml/contrib/Nms.h"
17 
18 namespace cloudViewer {
19 namespace ml {
20 namespace contrib {
21 
// Sorts `values` in place and records, for every sorted position, the index
// the element had before sorting.
//
// - values:       pointer to n keys, accessed through thrust as device memory;
//                 reordered in place.
// - sort_indices: pointer to n int64_t slots, accessed through thrust as device
//                 memory; overwritten with the resulting permutation.
// - n:            number of elements in both arrays.
// - descending:   if true, the largest key comes first; otherwise ascending.
//
// The sort is stable, so equal keys keep their original relative order.
template <typename T>
static void SortIndices(T *values,
                        int64_t *sort_indices,
                        int64_t n,
                        bool descending = false) {
    // Wrap the raw pointers so that thrust treats them as device iterators.
    const thrust::device_ptr<T> keys_begin =
            thrust::device_pointer_cast(values);
    const thrust::device_ptr<T> keys_end = keys_begin + n;
    const thrust::device_ptr<int64_t> perm_begin =
            thrust::device_pointer_cast(sort_indices);
    const thrust::device_ptr<int64_t> perm_end = perm_begin + n;

    // Start from the identity permutation 0, 1, ..., n-1.
    thrust::sequence(perm_begin, perm_end, 0);

    // Key-value sort: the permutation entries travel with their keys.
    if (!descending) {
        thrust::stable_sort_by_key(keys_begin, keys_end, perm_begin);
        return;
    }
    thrust::stable_sort_by_key(keys_begin, keys_end, perm_begin,
                               thrust::greater<T>());
}
44 
// Computes the pairwise overlap bitmask used by non-maximum suppression.
//
// Launch layout (as read from blockIdx/threadIdx below):
// - blockIdx.y selects a tile of NMS_BLOCK_SIZE "source" boxes (rows),
// - blockIdx.x selects a tile of NMS_BLOCK_SIZE "target" boxes (columns),
// - one thread per box in the tile, i.e. blockDim.x must be >= NMS_BLOCK_SIZE.
// Uses NMS_BLOCK_SIZE * 5 floats of static shared memory.
//
// Parameters:
// - boxes:              box data, 5 consecutive floats per box.
//                       NOTE(review): the meaning of the 5 floats is defined
//                       by IoUBev2DWithMinAndMax; confirm against IoUImpl.h.
// - sort_indices:       n indices into `boxes`; position k in this array is
//                       the k-th box in processing order.
// - mask:               output, n * num_block_cols 64-bit words. Bit k
//                       (counted from the right) of mask[i * num_block_cols + j]
//                       is 1 iff sorted box i overlaps sorted box
//                       NMS_BLOCK_SIZE * j + k by more than the threshold.
//                       On the diagonal tile only bits k > (i % NMS_BLOCK_SIZE)
//                       are evaluated.
// - n:                  number of boxes.
// - nms_overlap_thresh: IoU threshold; strictly greater counts as overlap.
// - num_block_cols:     number of 64-bit words per mask row.
//
// The bit shift below requires NMS_BLOCK_SIZE <= 64.
__global__ void NmsKernel(const float *boxes,
                          const int64_t *sort_indices,
                          uint64_t *mask,
                          const int n,
                          const double nms_overlap_thresh,
                          const int num_block_cols) {
    const int block_size = NMS_BLOCK_SIZE;
    const int tid = threadIdx.x;

    // Row-wise block index.
    const int block_row_idx = blockIdx.y;
    // Column-wise block index.
    const int block_col_idx = blockIdx.x;

    // Number of valid boxes in this row / column tile (the last tile may be
    // partial). Integer min avoids the int -> float -> int round trip of
    // fminf.
    const int row_size = min(n - block_row_idx * block_size, block_size);
    const int col_size = min(n - block_col_idx * block_size, block_size);

    // Stage the column tile's boxes in shared memory:
    // block_boxes = boxes[sorted[BS*block_col_idx : BS*block_col_idx+col_size]].
    //
    // TODO: It is also possible to load the comparison target to the shared
    // memory as well.
    __shared__ float block_boxes[NMS_BLOCK_SIZE * 5];
    if (tid < col_size) {
        float *dst = block_boxes + tid * 5;
        const int col_box_idx = block_size * block_col_idx + tid;
        const float *src = boxes + sort_indices[col_box_idx] * 5;
        dst[0] = src[0];
        dst[1] = src[1];
        dst[2] = src[2];
        dst[3] = src[3];
        dst[4] = src[4];
    }
    // Reached by every thread in the block: it is outside the branch above.
    __syncthreads();

    // Each thread owns one source box of the row tile and compares it with
    // every box of the column tile held in shared memory.
    if (tid < row_size) {
        // src_idx indexes the sorted order (global memory).
        const int src_idx = block_size * block_row_idx + tid;
        // The source box does not change inside the loop; resolve it once
        // instead of re-reading sort_indices on every iteration.
        const float *src_box = boxes + sort_indices[src_idx] * 5;

        // On the diagonal tile only compare against later boxes, so a box is
        // never compared with itself or with a higher-ranked box.
        const int first_dst = (block_row_idx == block_col_idx) ? tid + 1 : 0;

        uint64_t bits = 0;
        for (int dst_idx = first_dst; dst_idx < col_size; ++dst_idx) {
            if (IoUBev2DWithMinAndMax(src_box, block_boxes + dst_idx * 5) >
                nms_overlap_thresh) {
                bits |= 1ULL << dst_idx;
            }
        }

        // 64-bit offset: src_idx * num_block_cols exceeds INT_MAX once
        // n * ceil(n / NMS_BLOCK_SIZE) > 2^31.
        const int64_t mask_offset =
                static_cast<int64_t>(src_idx) * num_block_cols + block_col_idx;
        mask[mask_offset] = bits;
    }
}
109 
// Runs non-maximum suppression on the GPU and returns the kept box indices.
//
// Parameters:
// - boxes:              device pointer to n boxes, 5 floats per box (passed
//                       unchanged to NmsKernel).
// - scores:             device pointer to n scores; not modified (a device
//                       copy is sorted instead).
// - n:                  number of boxes. n <= 0 yields an empty result.
// - nms_overlap_thresh: IoU threshold above which a lower-scored box is
//                       suppressed by a kept higher-scored one.
//
// Returns the original indices of the kept boxes, ordered by descending
// score.
//
// Errors from the CUDA runtime are reported through CLOUDVIEWER_CUDA_CHECK.
std::vector<int64_t> NmsCUDAKernel(const float *boxes,
                                   const float *scores,
                                   int n,
                                   double nms_overlap_thresh) {
    if (n <= 0) {
        return {};
    }

    // Column-wise number of blocks, i.e. 64-bit mask words per box.
    const int num_block_cols = cloudViewer::utility::DivUp(n, NMS_BLOCK_SIZE);

    // Element counts are computed in size_t: n * num_block_cols overflows a
    // 32-bit int for large n.
    const size_t num_boxes = static_cast<size_t>(n);
    const size_t mask_count = num_boxes * static_cast<size_t>(num_block_cols);
    const size_t mask_bytes = mask_count * sizeof(uint64_t);

    // Compute sort indices (descending score) on a scratch copy of scores,
    // because the sort reorders its keys in place.
    float *scores_copy = nullptr;
    CLOUDVIEWER_CUDA_CHECK(
            cudaMalloc((void **)&scores_copy, num_boxes * sizeof(float)));
    CLOUDVIEWER_CUDA_CHECK(cudaMemcpy(scores_copy, scores,
                                      num_boxes * sizeof(float),
                                      cudaMemcpyDeviceToDevice));
    int64_t *sort_indices = nullptr;
    CLOUDVIEWER_CUDA_CHECK(
            cudaMalloc((void **)&sort_indices, num_boxes * sizeof(int64_t)));
    SortIndices(scores_copy, sort_indices, n, true);
    CLOUDVIEWER_CUDA_CHECK(cudaFree(scores_copy));

    // Allocate masks on device.
    uint64_t *mask_ptr = nullptr;
    CLOUDVIEWER_CUDA_CHECK(cudaMalloc((void **)&mask_ptr, mask_bytes));

    // Launch kernel: one block per (column tile, row tile) pair.
    dim3 blocks(num_block_cols, num_block_cols);
    dim3 threads(NMS_BLOCK_SIZE);
    NmsKernel<<<blocks, threads>>>(boxes, sort_indices, mask_ptr, n,
                                   nms_overlap_thresh, num_block_cols);
    // A kernel launch does not return an error itself; an invalid launch
    // configuration (e.g. a grid dimension above the device limit) is only
    // visible through cudaGetLastError().
    CLOUDVIEWER_CUDA_CHECK(cudaGetLastError());

    // Copy cuda masks to cpu. The blocking cudaMemcpy also waits for the
    // kernel to finish.
    std::vector<uint64_t> mask_vec(mask_count);
    CLOUDVIEWER_CUDA_CHECK(cudaMemcpy(mask_vec.data(), mask_ptr, mask_bytes,
                                      cudaMemcpyDeviceToHost));
    CLOUDVIEWER_CUDA_CHECK(cudaFree(mask_ptr));

    // Copy sort_indices to cpu; the device buffer is not needed afterwards.
    std::vector<int64_t> sort_indices_cpu(num_boxes);
    CLOUDVIEWER_CUDA_CHECK(cudaMemcpy(sort_indices_cpu.data(), sort_indices,
                                      num_boxes * sizeof(int64_t),
                                      cudaMemcpyDeviceToHost));
    CLOUDVIEWER_CUDA_CHECK(cudaFree(sort_indices));

    // Write to keep_indices in CPU.
    // remv_cpu has n bits in total. If the bit is 1, the corresponding
    // box will be removed.
    // TODO: This part can be implemented in CUDA. We use the original author's
    // implementation here.
    const uint64_t *mask = mask_vec.data();
    std::vector<uint64_t> remv_cpu(num_block_cols, 0);
    std::vector<int64_t> keep_indices;
    for (int i = 0; i < n; i++) {
        const int block_col_idx = i / NMS_BLOCK_SIZE;
        const int inner_block_col_idx = i % NMS_BLOCK_SIZE;  // threadIdx.x

        // Querying the i-th bit in remv_cpu, counted from the right.
        // - remv_cpu[block_col_idx]: the block bitmap containing the query
        // - 1ULL << inner_block_col_idx: the one-hot bitmap to extract i
        if (remv_cpu[block_col_idx] & (1ULL << inner_block_col_idx)) {
            // Already suppressed by a higher-scored kept box.
            continue;
        }

        // Keep the i-th box.
        keep_indices.push_back(sort_indices_cpu[i]);

        // Any box that overlaps with the i-th box will be removed. Only
        // words from block_col_idx onwards can contain lower-ranked boxes.
        const uint64_t *p =
                mask + static_cast<size_t>(i) * static_cast<size_t>(num_block_cols);
        for (int j = block_col_idx; j < num_block_cols; j++) {
            remv_cpu[j] |= p[j];
        }
    }
    return keep_indices;
}
187 
188 } // namespace contrib
189 } // namespace ml
190 } // namespace cloudViewer