1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
8 #include "ml/contrib/TrilinearDevoxelize.cuh"
namespace cloudViewer {
namespace ml {
namespace contrib {

/// Trilinearly interpolates voxel-grid features at continuous point
/// coordinates (devoxelization).
///
/// Launch layout: one block per batch element. `blockIdx.x` selects the batch
/// element, and the threads of that block stride over its `n` points with a
/// step of `blockDim.x`. No shared memory is used. `b` itself is not read by
/// the kernel; the grid is presumably launched with `b` blocks.
///
/// Memory layouts (row-major, innermost dimension last):
///   coords : [b, 3, n]   x, y and z stored as three planes of length n
///   feat   : [b, c, r3]  grid flattened as x * r2 + y * r + z
///   inds   : [b, 8, n]   corner indices, written only when is_training
///   wgts   : [b, 8, n]   corner weights, written only when is_training
///   outs   : [b, c, n]
///
/// @param b            Number of batch elements (unused, see above).
/// @param c            Number of feature channels.
/// @param n            Number of points per batch element.
/// @param r            Grid resolution along one axis.
/// @param r2           r * r.
/// @param r3           r * r * r, i.e. voxels per channel.
/// @param is_training  If true, store the 8 corner indices and weights of
///                     every point so the backward kernel can reuse them.
///
/// No bounds checking is performed. The low corner is floor(coord) and the
/// high corner is floor(coord) + 1 whenever the fractional part is non-zero.
/// Coordinates are therefore assumed to lie in [0, r - 1]; anything outside
/// that range indexes outside the grid.
__global__ void TrilinearDevoxelizeKernel(int b,
                                          int c,
                                          int n,
                                          int r,
                                          int r2,
                                          int r3,
                                          bool is_training,
                                          const float *__restrict__ coords,
                                          const float *__restrict__ feat,
                                          int *__restrict__ inds,
                                          float *__restrict__ wgts,
                                          float *__restrict__ outs) {
    // Batch offsets are formed in 64-bit. In 32-bit int, batch_index * c * r3
    // overflows once the offset passes INT_MAX (e.g. 8 x 256 x 128^3).
    const long long batch_index = blockIdx.x;
    const int stride = blockDim.x;
    const int index = threadIdx.x;
    coords += batch_index * n * 3;
    inds += batch_index * n * 8;
    wgts += batch_index * n * 8;
    feat += batch_index * c * r3;
    outs += batch_index * c * n;

    for (int i = index; i < n; i += stride) {
        const float x = coords[i];
        const float y = coords[i + n];
        const float z = coords[i + n + n];
        const float x_lo_f = floorf(x);
        const float y_lo_f = floorf(y);
        const float z_lo_f = floorf(z);

        // Fractional position inside the cell, measured from the low corner.
        const float x_d_1 = x - x_lo_f;
        const float y_d_1 = y - y_lo_f;
        const float z_d_1 = z - z_lo_f;
        const float x_d_0 = 1.0f - x_d_1;
        const float y_d_0 = 1.0f - y_d_1;
        const float z_d_0 = 1.0f - z_d_1;

        // Corner weights; the digits name the (x, y, z) corner (0 = lo, 1 = hi).
        const float wgt000 = x_d_0 * y_d_0 * z_d_0;
        const float wgt001 = x_d_0 * y_d_0 * z_d_1;
        const float wgt010 = x_d_0 * y_d_1 * z_d_0;
        const float wgt011 = x_d_0 * y_d_1 * z_d_1;
        const float wgt100 = x_d_1 * y_d_0 * z_d_0;
        const float wgt101 = x_d_1 * y_d_0 * z_d_1;
        const float wgt110 = x_d_1 * y_d_1 * z_d_0;
        const float wgt111 = x_d_1 * y_d_1 * z_d_1;

        const int x_lo = static_cast<int>(x_lo_f);
        const int y_lo = static_cast<int>(y_lo_f);
        const int z_lo = static_cast<int>(z_lo_f);

        // Flat-index step to the high neighbour along each axis. When a
        // coordinate lies exactly on a grid plane (fraction 0), the high corner
        // collapses onto the low one. Its weight is 0 in that case, and a point
        // sitting exactly on the last plane (r - 1) never indexes cell r.
        const int dx = (x_d_1 > 0.0f) ? r2 : 0;
        const int dy = (y_d_1 > 0.0f) ? r : 0;
        const int dz = (z_d_1 > 0.0f) ? 1 : 0;

        const int idx000 = x_lo * r2 + y_lo * r + z_lo;
        const int idx001 = idx000 + dz;
        const int idx010 = idx000 + dy;
        const int idx011 = idx010 + dz;
        const int idx100 = idx000 + dx;
        const int idx101 = idx100 + dz;
        const int idx110 = idx100 + dy;
        const int idx111 = idx110 + dz;

        if (is_training) {
            // Saved as [8, n] planes for TrilinearDevoxelizeGradKernel.
            wgts[i] = wgt000;
            wgts[i + n] = wgt001;
            wgts[i + n * 2] = wgt010;
            wgts[i + n * 3] = wgt011;
            wgts[i + n * 4] = wgt100;
            wgts[i + n * 5] = wgt101;
            wgts[i + n * 6] = wgt110;
            wgts[i + n * 7] = wgt111;

            inds[i] = idx000;
            inds[i + n] = idx001;
            inds[i + n * 2] = idx010;
            inds[i + n * 3] = idx011;
            inds[i + n * 4] = idx100;
            inds[i + n * 5] = idx101;
            inds[i + n * 6] = idx110;
            inds[i + n * 7] = idx111;
        }

        // Interpolate every channel at this point. The channel offsets are
        // 64-bit because c * r3 may also exceed INT_MAX.
        for (int j = 0; j < c; j++) {
            const long long jr3 = static_cast<long long>(j) * r3;
            const float *f = feat + jr3;
            outs[static_cast<long long>(j) * n + i] =
                    wgt000 * f[idx000] + wgt001 * f[idx001] +
                    wgt010 * f[idx010] + wgt011 * f[idx011] +
                    wgt100 * f[idx100] + wgt101 * f[idx101] +
                    wgt110 * f[idx110] + wgt111 * f[idx111];
        }
    }
}

/// Backward pass of TrilinearDevoxelizeKernel. It scatters the output
/// gradients back onto the voxel grid, using the corner indices and weights
/// saved by the forward pass (is_training == true).
///
/// Launch layout: one block per batch element (`blockIdx.x`). Threads stride
/// over the `n` points with a step of `blockDim.x`. `b` is not read by the
/// kernel.
///
/// Memory layouts:
///   inds, wgts : [b, 8, n]
///   grad_y     : [b, c, n]
///   grad_x     : [b, c, r3]
///
/// Several points can share a voxel, so contributions are accumulated with
/// atomicAdd. The kernel never clears grad_x, so the caller is expected to
/// zero it beforehand. The floating-point summation order (and thus the
/// last bits of the result) is not deterministic across runs.
__global__ void TrilinearDevoxelizeGradKernel(int b,
                                              int c,
                                              int n,
                                              int r3,
                                              const int *__restrict__ inds,
                                              const float *__restrict__ wgts,
                                              const float *__restrict__ grad_y,
                                              float *__restrict__ grad_x) {
    // 64-bit batch offsets, matching the forward kernel.
    const long long batch_index = blockIdx.x;
    const int stride = blockDim.x;
    const int index = threadIdx.x;
    inds += batch_index * n * 8;
    wgts += batch_index * n * 8;
    grad_x += batch_index * c * r3;
    grad_y += batch_index * c * n;

    for (int i = index; i < n; i += stride) {
        // Gather the 8 saved corners of this point. The loop is fully unrolled,
        // so the arrays stay in registers.
        int idx[8];
        float wgt[8];
#pragma unroll
        for (int k = 0; k < 8; ++k) {
            idx[k] = inds[i + n * k];
            wgt[k] = wgts[i + n * k];
        }

        for (int j = 0; j < c; j++) {
            float *gx = grad_x + static_cast<long long>(j) * r3;
            const float g = grad_y[static_cast<long long>(j) * n + i];
#pragma unroll
            for (int k = 0; k < 8; ++k) {
                atomicAdd(gx + idx[k], wgt[k] * g);
            }
        }
    }
}

}  // namespace contrib
}  // namespace ml
}  // namespace cloudViewer