ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
FileSPLAT.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
#include <FileSystem.h>
#include <Logging.h>
#include <ProgressReporters.h>
#include <tbb/parallel_sort.h>

#include <Eigen/Dense>
#include <cmath>
#include <cstring>
#include <fstream>
#include <limits>
#include <numeric>
#include <vector>

#include "cloudViewer/core/Dtype.h"
23 
24 namespace cloudViewer {
25 namespace t {
26 namespace io {
27 
28 namespace {
29 
// Zeroth-band real spherical-harmonics constant, 1 / (2 * sqrt(pi)). It
// converts between DC ("f_dc") SH coefficients and RGB values in [0, 1].
constexpr double SH_C0 = 0.28209479177387814;

// Size in bytes of one Gaussian record in a .splat file:
// 3 x float32 position + 3 x float32 scale + 4 x uint8 RGBA + 4 x uint8 rotation.
constexpr int SPLAT_GAUSSIAN_BYTE_SIZE = 32;

// Logistic sigmoid 1 / (1 + e^-x). It maps an opacity logit to [0, 1].
inline double sigmoid(double x) {
    const double denominator = 1.0 + std::exp(-x);
    return 1.0 / denominator;
}
35 
36 template <typename scalar_t>
37 Eigen::Array<uint8_t, 4, 1> ComputeColor(const scalar_t *f_dc_ptr,
38  const scalar_t *opacity_ptr) {
39  Eigen::Array<float, 4, 1> color;
40  color[0] = 0.5 + SH_C0 * f_dc_ptr[0];
41  color[1] = 0.5 + SH_C0 * f_dc_ptr[1];
42  color[2] = 0.5 + SH_C0 * f_dc_ptr[2];
43  color[3] = sigmoid(*opacity_ptr);
44  // Convert color to int (scale, clip, and cast)
45  return (color * 255).round().cwiseMin(255.0).cwiseMax(0.0).cast<uint8_t>();
46 }
47 
50 std::vector<int64_t> SortedSplatIndices(geometry::TensorMap &t_map) {
51  auto num_gaussians = t_map["opacity"].GetShape(0);
52  std::vector<int64_t> indices(num_gaussians);
53  std::iota(indices.begin(), indices.end(), 0);
54 
55  // Get pointers to data
56  const float *scale_data = t_map["scale"].GetDataPtr<float>();
57  const float *opacity_data = t_map["opacity"].GetDataPtr<float>();
58  const auto scle_grp_size = t_map["scale"].GetShape(1);
59 
60  // Custom sorting function using the given formula
61  tbb::parallel_sort(
62  indices.begin(), indices.end(),
63  [&](size_t left, size_t right) -> bool {
64  // Compute scores for left and right elements
65  float scale_left = scale_data[left * scle_grp_size] +
66  scale_data[left * scle_grp_size + 1] +
67  scale_data[left * scle_grp_size + 2];
68  float scale_right = scale_data[right * scle_grp_size] +
69  scale_data[right * scle_grp_size + 1] +
70  scale_data[right * scle_grp_size + 2];
71 
72  float score_left = -std::exp(scale_left) /
73  (1 + std::exp(-opacity_data[left]));
74  float score_right = -std::exp(scale_right) /
75  (1 + std::exp(-opacity_data[right]));
76 
77  return score_left < score_right; // Sort in descending order
78  });
79  return indices;
80 }
81 
82 } // End of anonymous namespace
83 
85  const std::string &filename,
86  geometry::PointCloud &pointcloud,
88  try {
89  // Open the file
91  if (!file.Open(filename, "rb")) {
92  utility::LogWarning("Read SPLAT failed: unable to open file: {}",
93  filename);
94  return false;
95  }
96  pointcloud.Clear();
97 
98  size_t file_size = file.GetFileSize();
99  if (file_size == 0 || file_size % SPLAT_GAUSSIAN_BYTE_SIZE > 0) {
101  "Read SPLAT failed: file {} does not contain "
102  "a whole number of Gaussians. File Size {}"
103  " bytes, Gaussian Size {} bytes.",
104  filename, file_size, SPLAT_GAUSSIAN_BYTE_SIZE);
105  return false;
106  }
107 
108  // Report progress
109  utility::CountingProgressReporter reporter(params.update_progress);
110  reporter.SetTotal(file.GetFileSize());
111 
112  // Constants
113  char buffer[SPLAT_GAUSSIAN_BYTE_SIZE];
114  const char *buffer_position = buffer;
115  const char *buffer_scale = buffer_position + 3 * sizeof(float);
116  const uint8_t *buffer_color =
117  reinterpret_cast<const uint8_t *>(buffer_scale) +
118  3 * sizeof(float);
119  const uint8_t *buffer_rotation = buffer_color + 4 * sizeof(uint8_t);
120  int number_of_points =
121  static_cast<int>(file_size / SPLAT_GAUSSIAN_BYTE_SIZE);
122 
123  // Positions
124  pointcloud.SetPointPositions(
125  core::Tensor::Empty({number_of_points, 3}, core::Float32));
126  // Scale
127  pointcloud.SetPointAttr(
128  "scale",
129  core::Tensor::Empty({number_of_points, 3}, core::Float32));
130  // Rots
131  pointcloud.SetPointAttr(
132  "rot",
133  core::Tensor::Empty({number_of_points, 4}, core::Float32));
134  // f_dc
135  pointcloud.SetPointAttr(
136  "f_dc",
137  core::Tensor::Empty({number_of_points, 3}, core::Float32));
138  // Opacity
139  pointcloud.SetPointAttr(
140  "opacity",
141  core::Tensor::Empty({number_of_points, 1}, core::Float32));
142 
143  float *position_ptr =
144  pointcloud.GetPointPositions().GetDataPtr<float>();
145  float *scale_ptr = pointcloud.GetPointAttr("scale").GetDataPtr<float>();
146  float *f_dc_ptr = pointcloud.GetPointAttr("f_dc").GetDataPtr<float>();
147  float *opacity_ptr =
148  pointcloud.GetPointAttr("opacity").GetDataPtr<float>();
149  float *rot_ptr = pointcloud.GetPointAttr("rot").GetDataPtr<float>();
150 
151  // Read the data
152  for (size_t index = 0; file.ReadData(buffer, SPLAT_GAUSSIAN_BYTE_SIZE);
153  ++index) {
154  // Copy the data into the vectors
155  std::memcpy(position_ptr + index * 3, buffer_position,
156  3 * sizeof(float));
157  std::memcpy(scale_ptr + index * 3, buffer_scale, 3 * sizeof(float));
158 
159  // Calculate the f_dc
160  float *f_dc = f_dc_ptr + index * 3;
161  for (int i = 0; i < 3; i++) {
162  f_dc[i] = ((buffer_color[i] / 255.0) - 0.5) / SH_C0;
163  }
164  // Calculate the opacity
165  float *opacity = opacity_ptr + index;
166  if (buffer_color[3] == 0) {
167  opacity[0] = 0.0f; // Handle division by zero
168  } else if (buffer_color[3] == 255) {
169  opacity[0] = -std::numeric_limits<float>::lowest(); // -log(0)
170  } else {
171  opacity[0] = -log(1 / (buffer_color[3] / 255.0) - 1);
172  }
173  // Calculate the rotation quaternion.
174  // Normalize to reduce quantization error
175  float *rot_float = rot_ptr + index * 4;
176  float quat_norm = 0;
177  for (int i = 0; i < 4; i++) {
178  rot_float[i] = (buffer_rotation[i] / 128.0) - 1.0;
179  quat_norm += rot_float[i] * rot_float[i];
180  }
181  quat_norm = sqrt(quat_norm);
182  if (quat_norm > std::numeric_limits<float>::epsilon()) {
183  for (int i = 0; i < 4; i++) {
184  rot_float[i] /= quat_norm;
185  }
186  } else { // gsplat quat convention is wxyz
187  rot_float[0] = 1.0f;
188  rot_float[1] = 0.0f;
189  rot_float[2] = 0.0f;
190  rot_float[3] = 0.0f;
191  }
192 
193  if (index % 1000 == 0) {
194  reporter.Update(file.CurPos());
195  }
196  }
197 
198  // Report progress
199  reporter.Finish();
200  return true;
201  } catch (const std::exception &e) {
202  utility::LogError("Read SPLAT file {} failed: {}", filename, e.what());
203  }
204  return false;
205 }
206 
208  const std::string &filename,
209  const geometry::PointCloud &pointcloud,
211  // Validate Point Cloud
212  if (pointcloud.IsEmpty()) {
213  utility::LogWarning("Write SPLAT failed: point cloud has 0 points.");
214  return false;
215  }
216 
217  // Validate Splat Data
218  if (!pointcloud.IsGaussianSplat()) {
220  "Write SPLAT failed: point cloud is not a Gaussian Splat.");
221  return false;
222  }
223  geometry::TensorMap t_map = pointcloud.GetPointAttr();
224 
225  // Convert to float32, make contiguous and move to CPU.
226  // Some of these operations may be no-ops. This specific order of
227  // operations ensures efficiency.
228  for (auto attr : {"positions", "scale", "rot", "f_dc", "opacity"}) {
229  t_map[attr] = t_map[attr]
230  .To(core::Float32)
231  .Contiguous()
232  .To(core::Device("CPU:0"));
233  }
234  float *positions_ptr = t_map["positions"].GetDataPtr<float>();
235  float *scale_ptr = t_map["scale"].GetDataPtr<float>();
236  float *f_dc_ptr = t_map["f_dc"].GetDataPtr<float>();
237  float *opacity_ptr = t_map["opacity"].GetDataPtr<float>();
238  float *rot_ptr = t_map["rot"].GetDataPtr<float>();
239  constexpr int N_POSITIONS = 3;
240  constexpr int N_SCALE = 3;
241  constexpr int N_F_DC = 3;
242  constexpr int N_OPACITY = 1;
243  constexpr int N_ROT = 4;
244 
245  // Total Gaussians
246  long num_gaussians =
247  static_cast<long>(pointcloud.GetPointPositions().GetLength());
248 
249  // Open splat file
250  auto splat_file = std::ofstream(filename, std::ios::binary);
251  try {
252  splat_file.exceptions(std::ofstream::badbit); // failbit not set for
253  // binary IO errors
254  } catch (const std::ios_base::failure &) {
255  utility::LogWarning("Write SPLAT failed: unable to open file: {}.",
256  filename);
257  return false;
258  }
259 
260  // Write to SPLAT
261  utility::CountingProgressReporter reporter(params.update_progress);
262  reporter.SetTotal(num_gaussians);
263 
264  std::vector<int64_t> sorted_indices = SortedSplatIndices(t_map);
265 
266  try {
267  for (int64_t i = 0; i < num_gaussians; i++) {
268  int64_t g_idx = sorted_indices[i];
269 
270  // Positions
271  splat_file.write(reinterpret_cast<const char *>(
272  positions_ptr + N_POSITIONS * g_idx),
273  N_POSITIONS * sizeof(float));
274 
275  // Scale
276  splat_file.write(
277  reinterpret_cast<const char *>(scale_ptr + N_SCALE * g_idx),
278  N_SCALE * sizeof(float));
279 
280  // Color
281  auto color = ComputeColor(f_dc_ptr + N_F_DC * g_idx,
282  opacity_ptr + N_OPACITY * g_idx);
283  splat_file.write(reinterpret_cast<const char *>(color.data()),
284  4 * sizeof(uint8_t));
285 
286  // Rot
287  int rot_offset = N_ROT * g_idx;
288  Eigen::Vector4f rot{rot_ptr[rot_offset], rot_ptr[rot_offset + 1],
289  rot_ptr[rot_offset + 2],
290  rot_ptr[rot_offset + 3]};
291  if (auto quat_norm = rot.norm();
292  quat_norm > std::numeric_limits<float>::epsilon()) {
293  rot /= quat_norm;
294  } else {
295  rot = {1.f, 0.f, 0.f, 0.f}; // wxyz quaternion
296  }
297  // offset should be 127, but we follow the reference
298  // antimatter/convert.py code
299  rot = (rot * 128.0).array().round() + 128.0;
300  auto uint8_rot =
301  rot.cwiseMin(255.0).cwiseMax(0.0).cast<uint8_t>().eval();
302  splat_file.write(reinterpret_cast<const char *>(uint8_rot.data()),
303  4 * sizeof(uint8_t));
304 
305  if (i % 1000 == 0) {
306  reporter.Update(i);
307  }
308  }
309  splat_file.close(); // Close file, flushes to disk.
310  reporter.Finish();
311  return true;
312  } catch (const std::ios_base::failure &e) {
313  utility::LogWarning("Write SPLAT to file {} failed: {}", filename,
314  e.what());
315  return false;
316  }
317 }
318 
319 } // namespace io
320 } // namespace t
321 } // namespace cloudViewer
std::string filename
math::float4 color
cmdLineReadable * params[]
int64_t GetLength() const
Definition: Tensor.h:1125
static Tensor Empty(const SizeVector &shape, Dtype dtype, const Device &device=Device("CPU:0"))
Create a tensor with uninitialized values.
Definition: Tensor.cpp:400
A point cloud contains a list of 3D points.
Definition: PointCloud.h:82
void SetPointPositions(const core::Tensor &value)
Set the value of the "positions" attribute. Convenience function.
Definition: PointCloud.h:186
core::Tensor & GetPointPositions()
Get the value of the "positions" attribute. Convenience function.
Definition: PointCloud.h:124
bool IsEmpty() const override
Returns !HasPointPositions().
Definition: PointCloud.h:257
const TensorMap & GetPointAttr() const
Getter for point_attr_ TensorMap. Used in Pybind.
Definition: PointCloud.h:111
PointCloud & Clear() override
Clear all data in the point cloud.
Definition: PointCloud.h:251
void SetPointAttr(const std::string &key, const core::Tensor &value)
Definition: PointCloud.h:177
bool Open(const std::string &filename, const std::string &mode)
Open a file.
Definition: FileSystem.cpp:739
int64_t CurPos()
Returns current position in the file (ftell).
Definition: FileSystem.cpp:760
int64_t GetFileSize()
Returns the file size in bytes.
Definition: FileSystem.cpp:772
size_t ReadData(T *data, size_t num_elems)
Definition: FileSystem.h:253
#define LogWarning(...)
Definition: Logging.h:72
#define LogError(...)
Definition: Logging.h:60
const Dtype Float32
Definition: Dtype.cpp:42
bool ReadPointCloudFromSPLAT(const std::string &filename, geometry::PointCloud &pointcloud, const cloudViewer::io::ReadPointCloudOption &params)
Definition: FileSPLAT.cpp:84
bool WritePointCloudToSPLAT(const std::string &filename, const geometry::PointCloud &pointcloud, const cloudViewer::io::WritePointCloudOption &params)
Definition: FileSPLAT.cpp:207
Generic file read and write utility for python interface.
Optional parameters to ReadPointCloud.
Definition: FileIO.h:39
Optional parameters to WritePointCloud.
Definition: FileIO.h:77