ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
InterpolatePoints.cu
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #include "ml/contrib/InterpolatePoints.cuh"
9 
10 namespace cloudViewer {
11 namespace ml {
12 namespace contrib {
13 
// Finds, for every query point, the three closest reference points of the
// same batch item (brute-force scan over all m reference points).
//
// Launch layout read by the kernel: blockIdx.y selects the batch item and
// blockIdx.x * blockDim.x + threadIdx.x selects the query point. Threads that
// fall outside [0, b) x [0, n) return immediately. No shared memory is used.
//
// b:       batch size B
// n:       number of query points per batch item
// m:       number of reference points per batch item
// unknown: query coordinates, (B, N, 3)
// known:   reference coordinates, (B, M, 3)
// dist2:   output, (B, N, 3) - squared distances, in ascending order
// idx:     output, (B, N, 3) - positions along the M axis of `known`
//
// When m < 3 the slots that were never filled keep index 0 and distance +inf.
__global__ void three_nn_kernel(int b,
                                int n,
                                int m,
                                const float *__restrict__ unknown,
                                const float *__restrict__ known,
                                float *__restrict__ dist2,
                                int *__restrict__ idx) {
    const int bs_idx = blockIdx.y;
    const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    // Flattened offsets are computed in 64 bit so that B * N * 3 or
    // B * M * 3 larger than INT_MAX does not overflow.
    const size_t query_offset =
            (static_cast<size_t>(bs_idx) * n + pt_idx) * 3;
    unknown += query_offset;
    dist2 += query_offset;
    idx += query_offset;
    known += static_cast<size_t>(bs_idx) * m * 3;

    const float ux = unknown[0];
    const float uy = unknown[1];
    const float uz = unknown[2];

    // Single-precision +inf sentinel (bit pattern 0x7f800000). Keeping the
    // running minima in float avoids promoting every comparison to double.
    const float inf = __int_as_float(0x7f800000);
    float best1 = inf, best2 = inf, best3 = inf;
    int besti1 = 0, besti2 = 0, besti3 = 0;

    for (int k = 0; k < m; ++k) {
        const float *ref = known + static_cast<size_t>(k) * 3;
        const float x = ref[0];
        const float y = ref[1];
        const float z = ref[2];
        const float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) +
                        (uz - z) * (uz - z);

        // Insert d into the sorted triple (best1 <= best2 <= best3),
        // shifting the worse entries down. Strict '<' keeps the earlier
        // index ahead on ties.
        if (d < best1) {
            best3 = best2;
            besti3 = besti2;
            best2 = best1;
            besti2 = besti1;
            best1 = d;
            besti1 = k;
        } else if (d < best2) {
            best3 = best2;
            besti3 = besti2;
            best2 = d;
            besti2 = k;
        } else if (d < best3) {
            best3 = d;
            besti3 = k;
        }
    }

    dist2[0] = best1;
    dist2[1] = best2;
    dist2[2] = best3;
    idx[0] = besti1;
    idx[1] = besti2;
    idx[2] = besti3;
}
72 
// Interpolates per-point features as a weighted sum of three source features:
//   out[b, c, p] = sum_{j in 0..2} weight[b, p, j] * points[b, c, idx[b, p, j]]
//
// Launch layout read by the kernel: blockIdx.z selects the batch item,
// blockIdx.y the channel, and blockIdx.x * blockDim.x + threadIdx.x the
// output point. Threads outside [0, b) x [0, c) x [0, n) return immediately.
//
// b:      batch size B
// c:      number of channels C
// m:      number of source points per batch item
// n:      number of output points per batch item
// points: source features, (B, C, M)
// idx:    source positions along the M axis, (B, N, 3)
// weight: interpolation weights, (B, N, 3)
// out:    output features, (B, C, N)
//
// The values in idx are used as-is to index `points`; they are not
// range-checked here.
__global__ void three_interpolate_kernel(int b,
                                         int c,
                                         int m,
                                         int n,
                                         const float *__restrict__ points,
                                         const int *__restrict__ idx,
                                         const float *__restrict__ weight,
                                         float *__restrict__ out) {
    const int bs_idx = blockIdx.z;
    const int c_idx = blockIdx.y;
    const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    // Flattened offsets are computed in 64 bit so that tensors with more
    // than INT_MAX elements do not overflow the index arithmetic.
    const size_t neighbor_offset =
            (static_cast<size_t>(bs_idx) * n + pt_idx) * 3;
    const size_t channel_row = static_cast<size_t>(bs_idx) * c + c_idx;

    weight += neighbor_offset;
    idx += neighbor_offset;
    points += channel_row * m;
    out += channel_row * n;

    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
                  weight[2] * points[idx[2]];
}
101 
// Scatters the output gradient of the three-point interpolation back onto the
// source features:
//   grad_points[b, c, idx[b, p, j]] += grad_out[b, c, p] * weight[b, p, j]
// for j in 0..2.
//
// Launch layout read by the kernel: blockIdx.z selects the batch item,
// blockIdx.y the channel, and blockIdx.x * blockDim.x + threadIdx.x the
// output point. Threads outside [0, b) x [0, c) x [0, n) return immediately.
//
// NOTE: the size parameters are ordered (b, c, n, m) here, i.e. n before m.
//
// b:           batch size B
// c:           number of channels C
// n:           number of output points per batch item
// m:           number of source points per batch item
// grad_out:    gradient w.r.t. the interpolated features, (B, C, N)
// idx:         source positions along the M axis, (B, N, 3)
// weight:      interpolation weights, (B, N, 3)
// grad_points: output, (B, C, M); contributions are accumulated with
//              atomicAdd because several threads may target the same element.
//              The kernel only adds to the buffer and never clears it.
//              NOTE(review): presumably the caller zero-initializes it -
//              confirm against the launcher.
__global__ void three_interpolate_grad_kernel(
        int b,
        int c,
        int n,
        int m,
        const float *__restrict__ grad_out,
        const int *__restrict__ idx,
        const float *__restrict__ weight,
        float *__restrict__ grad_points) {
    const int bs_idx = blockIdx.z;
    const int c_idx = blockIdx.y;
    const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    // Flattened offsets are computed in 64 bit so that tensors with more
    // than INT_MAX elements do not overflow the index arithmetic.
    const size_t neighbor_offset =
            (static_cast<size_t>(bs_idx) * n + pt_idx) * 3;
    const size_t channel_row = static_cast<size_t>(bs_idx) * c + c_idx;

    grad_out += channel_row * n + pt_idx;
    weight += neighbor_offset;
    idx += neighbor_offset;
    grad_points += channel_row * m;

    const float g = grad_out[0];
    atomicAdd(grad_points + idx[0], g * weight[0]);
    atomicAdd(grad_points + idx[1], g * weight[1]);
    atomicAdd(grad_points + idx[2], g * weight[2]);
}
131 
132 } // namespace contrib
133 } // namespace ml
134 } // namespace cloudViewer