ACloudViewer  3.9.4
A Modern Library for 3D Data Processing
TensorCheck.cpp
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - CloudViewer: www.cloudViewer.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2024 www.cloudViewer.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
9 
10 #include <Helper.h>
11 #include <Logging.h>
12 
13 #include <string>
14 
16 #include "cloudViewer/core/Dtype.h"
18 
19 namespace cloudViewer {
20 namespace core {
21 namespace tensor_check {
22 
23 void AssertTensorDtype_(const char* file,
24  int line,
25  const char* function,
26  const Tensor& tensor,
27  const Dtype& dtype) {
28  if (tensor.GetDtype() == dtype) {
29  return;
30  }
31  std::string error_message =
32  fmt::format("Tensor has dtype {}, but is expected to have {}.",
33  tensor.GetDtype().ToString(), dtype.ToString());
34  utility::Logger::LogError_(file, line, function, error_message.c_str());
35 }
36 
37 void AssertTensorDtypes_(const char* file,
38  int line,
39  const char* function,
40  const Tensor& tensor,
41  const std::vector<Dtype>& dtypes) {
42  for (auto& it : dtypes) {
43  if (tensor.GetDtype() == it) {
44  return;
45  }
46  }
47 
48  std::vector<std::string> dtype_strings;
49  for (const Dtype& dtype : dtypes) {
50  dtype_strings.push_back(dtype.ToString());
51  }
52  std::string error_message = fmt::format(
53  "Tensor has dtype {}, but is expected to have dtype among {{{}}}.",
54  tensor.GetDtype().ToString(), utility::JoinStrings(dtype_strings));
55  utility::Logger::LogError_(file, line, function, error_message.c_str());
56 }
57 
58 void AssertTensorDevice_(const char* file,
59  int line,
60  const char* function,
61  const Tensor& tensor,
62  const Device& device) {
63  if (tensor.GetDevice() == device) {
64  return;
65  }
66  std::string error_message =
67  fmt::format("Tensor has device {}, but is expected to have {}.",
68  tensor.GetDevice().ToString(), device.ToString());
69  utility::Logger::LogError_(file, line, function, error_message.c_str());
70 }
71 
72 void AssertTensorShape_(const char* file,
73  int line,
74  const char* function,
75  const Tensor& tensor,
76  const DynamicSizeVector& shape) {
77  if (shape.IsDynamic()) {
78  if (tensor.GetShape().IsCompatible(shape)) {
79  return;
80  }
81  std::string error_message = fmt::format(
82  "Tensor has shape {}, but is expected to have compatible with "
83  "{}.",
84  tensor.GetShape().ToString(), shape.ToString());
85  utility::Logger::LogError_(file, line, function, error_message.c_str());
86  } else {
87  SizeVector static_shape = shape.ToSizeVector();
88  if (tensor.GetShape() == static_shape) {
89  return;
90  }
91  std::string error_message = fmt::format(
92  "Tensor has shape {}, but is expected to have {}.",
93  tensor.GetShape().ToString(), static_shape.ToString());
94  utility::Logger::LogError_(file, line, function, error_message.c_str());
95  }
96 }
97 
98 } // namespace tensor_check
99 } // namespace core
100 } // namespace cloudViewer
filament::Texture::InternalFormat format
std::string ToString() const
Returns string representation of device, e.g. "CPU:0", "CUDA:0".
Definition: Device.cpp:89
std::string ToString() const
Definition: Dtype.h:65
bool IsCompatible(const DynamicSizeVector &dsv) const
Definition: SizeVector.cpp:149
std::string ToString() const
Definition: SizeVector.cpp:132
Dtype GetDtype() const
Definition: Tensor.h:1164
Device GetDevice() const override
Definition: Tensor.cpp:1435
SizeVector GetShape() const
Definition: Tensor.h:1127
static void LogError_(const char *file, int line, const char *function, const char *format, Args &&...args)
Definition: Logging.h:189
Helper functions for the ml ops.
void AssertTensorDtypes_(const char *file, int line, const char *function, const Tensor &tensor, const std::vector< Dtype > &dtypes)
Definition: TensorCheck.cpp:37
void AssertTensorDtype_(const char *file, int line, const char *function, const Tensor &tensor, const Dtype &dtype)
Definition: TensorCheck.cpp:23
void AssertTensorDevice_(const char *file, int line, const char *function, const Tensor &tensor, const Device &device)
Definition: TensorCheck.cpp:58
void AssertTensorShape_(const char *file, int line, const char *function, const Tensor &tensor, const DynamicSizeVector &shape)
Definition: TensorCheck.cpp:72
std::string JoinStrings(const std::vector< std::string > &strs, const std::string &delimiter=", ")
Definition: Helper.cpp:168
Generic file read and write utility for python interface.