1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
8 #include "ml/contrib/InterpolatePoints.cuh"
10 namespace cloudViewer {
// three_nn_kernel: for every query point in `unknown`, scan all `m` reference
// points of the same batch in `known` and keep the three smallest squared
// Euclidean distances (best1..best3) together with their indices into
// `known` (besti1..besti3). No sqrt is taken: `d` is a squared distance.
14 __global__ void three_nn_kernel(int b,
17 const float *__restrict__ unknown,
18 const float *__restrict__ known,
19 float *__restrict__ dist2,
20 int *__restrict__ idx) {
// Launch layout: blockIdx.y selects the batch (< b); the x dimension of the
// grid/block enumerates query points (< n). Out-of-range threads exit early.
27 int bs_idx = blockIdx.y;
28 int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
29 if (bs_idx >= b || pt_idx >= n) return;
// Per-thread pointer offsets. The arithmetic implies contiguous layouts
// unknown/dist2/idx: (B, N, 3) and known: (B, M, 3); dist2/idx end up
// pointing at this query point's three output slots.
// NOTE(review): offsets are computed in 32-bit int and can overflow once
// B * N * 3 (or B * M * 3) exceeds INT_MAX — consider 64-bit indexing.
31 unknown += bs_idx * n * 3 + pt_idx * 3;
32 known += bs_idx * m * 3;
33 dist2 += bs_idx * n * 3 + pt_idx * 3;
34 idx += bs_idx * n * 3 + pt_idx * 3;
36 float ux = unknown[0];
37 float uy = unknown[1];
38 float uz = unknown[2];
// 1e40 is outside the float range (FLT_MAX is about 3.4e38), which is why the
// "no neighbour yet" sentinels are doubles.
// NOTE(review): every comparison then promotes the float `d` to double, which
// is slow on GPUs with low FP64 throughput; a float sentinel such as FLT_MAX
// or INFINITY would keep the search in FP32.
// NOTE(review): if m < 3, the slots the loop never fills keep the 1e40
// sentinel and index 0 — callers should not rely on those entries.
40 double best1 = 1e40, best2 = 1e40, best3 = 1e40;
41 int besti1 = 0, besti2 = 0, besti3 = 0;
// Brute-force O(m) scan over every reference point of this batch; each
// candidate is ranked against the current best1/best2/best3.
42 for (int k = 0; k < m; ++k) {
43 float x = known[k * 3 + 0];
44 float y = known[k * 3 + 1];
45 float z = known[k * 3 + 2];
47 (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
55 } else if (d < best2) {
60 } else if (d < best3) {
// three_interpolate_kernel: for each (batch, channel, target point) compute a
// weighted blend of three source features:
//   out[b][ch][p] = sum_{j=0..2} weight[b][p][j] * points[b][ch][idx[b][p][j]]
// The offset arithmetic implies layouts points: (B, C, M),
// idx/weight: (B, N, 3) and out: (B, C, N).
73 __global__ void three_interpolate_kernel(int b,
77 const float *__restrict__ points,
78 const int *__restrict__ idx,
79 const float *__restrict__ weight,
80 float *__restrict__ out) {
// Launch layout: blockIdx.z = batch (< b), blockIdx.y = channel (< c), and
// the x dimension of the grid/block enumerates target points (< n).
// Out-of-range threads exit early.
87 int bs_idx = blockIdx.z;
88 int c_idx = blockIdx.y;
89 int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
91 if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
// weight/idx -> this target point's 3 entries; points/out -> the start of
// this (batch, channel) row.
93 weight += bs_idx * n * 3 + pt_idx * 3;
94 points += bs_idx * c * m + c_idx * m;
95 idx += bs_idx * n * 3 + pt_idx * 3;
96 out += bs_idx * c * n + c_idx * n;
// NOTE(review): idx values are used unchecked as offsets into the m source
// features of this row; they are assumed to lie in [0, m).
98 out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] +
99 weight[2] * points[idx[2]];
// three_interpolate_grad_kernel: gradient of the three-point weighted blend
//   out[b][ch][p] = sum_j weight[b][p][j] * points[b][ch][idx[b][p][j]]
// with respect to `points`. Each (batch, channel, target point) thread
// scatters grad_out[b][ch][p] * weight[b][p][j] into
// grad_points[b][ch][idx[b][p][j]] for j = 0..2.
// Offsets imply idx/weight: (B, N, 3), in addition to the layouts below.
102 __global__ void three_interpolate_grad_kernel(
107 const float *__restrict__ grad_out,
108 const int *__restrict__ idx,
109 const float *__restrict__ weight,
110 float *__restrict__ grad_points) {
111 // grad_out: (B, C, N)
114 // grad_points: (B, C, M)
// Launch layout: blockIdx.z = batch (< b), blockIdx.y = channel (< c), and
// the x dimension of the grid/block enumerates target points (< n).
// Out-of-range threads exit early.
116 int bs_idx = blockIdx.z;
117 int c_idx = blockIdx.y;
118 int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
120 if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
// grad_out -> this thread's single upstream gradient value; weight/idx ->
// this target point's 3 entries; grad_points -> this (batch, channel) row.
122 grad_out += bs_idx * c * n + c_idx * n + pt_idx;
123 weight += bs_idx * n * 3 + pt_idx * 3;
124 grad_points += bs_idx * c * m + c_idx * m;
125 idx += bs_idx * n * 3 + pt_idx * 3;
// atomicAdd because different target points may reference the same source
// index, so several threads can update one grad_points entry concurrently.
// The kernel only accumulates, it never clears grad_points.
// NOTE(review): presumably the caller zero-initializes grad_points — confirm.
// NOTE(review): float atomics make the summation order (and therefore the
// low-order bits of the result) vary from run to run.
127 atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
128 atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
129 atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
132 } // namespace contrib
134 } // namespace cloudViewer