# ----------------------------------------------------------------------------
# - CloudViewer: www.cloudViewer.org -
# ----------------------------------------------------------------------------
# Copyright (c) 2018-2024 www.cloudViewer.org
# SPDX-License-Identifier: MIT
# ----------------------------------------------------------------------------
import torch
import numpy as np
__all__ = ['RaggedTensor']
class RaggedTensor:
"""RaggedTensor.
A RaggedTensor is a tensor with ragged dimension, whose slice may
have different lengths. We define a container for ragged tensor to
support operations involving batches whose elements may have different
shape.
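
    Example:

        A minimal usage sketch (illustrative; uses the `from_row_splits`
        factory defined below)::

            >>> rt = RaggedTensor.from_row_splits(
            ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
            ...     row_splits=[0, 4, 4, 7, 8, 8])
            >>> rt.shape
            [5, None]
            >>> rt.values.shape
            torch.Size([8])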
"""
    def __init__(self, r_tensor, internal=False):
"""Creates a `RaggedTensor` with specified torch script object.
This constructor is private -- please use one of the following ops
to build `RaggedTensor`'s:
* `ml3d.classes.RaggedTensor.from_row_splits`
Raises:
ValueError: If internal = False. This method is intented for internal use.
"""
        if not internal:
            raise ValueError(
                "RaggedTensor constructor is private, please use one of the "
                "factory methods instead (e.g. RaggedTensor.from_row_splits())"
            )
self.r_tensor = r_tensor
    @classmethod
def from_row_splits(cls, values, row_splits, validate=True, copy=True):
"""Creates a RaggedTensor with rows partitioned by row_splits.
        The returned `RaggedTensor` corresponds to the Python list defined by::

            result = [values[row_splits[i]:row_splits[i + 1]]
                      for i in range(len(row_splits) - 1)]

        Args:
            values: A tensor whose first dimension has length `N`, holding the
                concatenated rows.
            row_splits: A 1-D integer tensor with one entry per row plus one.
                Must not be empty, and must be sorted in ascending order.
                `row_splits[0]` must be zero and `row_splits[-1]` must be `N`.
            validate: Verify that `row_splits` is compatible with `values`.
                Set it to False to avoid expensive checks.
            copy: Whether to make a deep copy of `values` and `row_splits`.
                Set it to False to save memory for short-term usage.
Returns:
A `RaggedTensor` container.
Example:
>>> print(ml3d.classes.RaggedTensor.from_row_splits(
... values=[3, 1, 4, 1, 5, 9, 2, 6],
... row_splits=[0, 4, 4, 7, 8, 8]))
<RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
"""
        # Normalize `values` to a torch.Tensor, cloning it when a copy is
        # requested so the caller's tensor is not shared.
        if isinstance(values, list):
            values = torch.tensor(values, dtype=torch.float64)
        elif isinstance(values, np.ndarray):
            values = torch.from_numpy(values)
        elif isinstance(values, torch.Tensor) and copy:
            values = values.clone()

        # Normalize `row_splits` the same way; splits are int64 indices.
        if isinstance(row_splits, list):
            row_splits = torch.tensor(row_splits, dtype=torch.int64)
        elif isinstance(row_splits, np.ndarray):
            row_splits = torch.from_numpy(row_splits)
        elif isinstance(row_splits, torch.Tensor) and copy:
            row_splits = row_splits.clone()

        # Build the underlying custom TorchScript RaggedTensor object.
        r_tensor = torch.classes.my_classes.RaggedTensor().from_row_splits(
            values, row_splits, validate)
        return cls(r_tensor, internal=True)
@property
def values(self):
"""The concatenated rows for this ragged tensor."""
return self.r_tensor.get_values()
@property
def row_splits(self):
"""The row-split indices for this ragged tensor's `values`."""
return self.r_tensor.get_row_splits()
@property
def dtype(self):
"""The `DType` of values in this ragged tensor."""
return self.values.dtype
@property
def device(self):
"""The device of values in this ragged tensor."""
return self.values.device
@property
def shape(self):
"""The statically known shape of this ragged tensor."""
return [len(self.r_tensor), None, *self.values.shape[1:]]
@property
def requires_grad(self):
"""Read/writeble `requires_grad` for values."""
return self.values.requires_grad
@requires_grad.setter
def requires_grad(self, value):
self.values.requires_grad = value
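
    # A minimal sketch of gradient use (assumes float `values`; gradients flow
    # through the flat `values` tensor, not through `row_splits`):
    #
    #   rt = RaggedTensor.from_row_splits([1., 2., 3.], [0, 2, 3])
    #   rt.requires_grad = True
    #   rt.values.sum().backward()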
    def clone(self):
        """Returns a clone of the object."""
        return RaggedTensor(self.r_tensor.clone(), True)
    def to_list(self):
        """Returns the rows of this ragged tensor as a list of tensors."""
return [tensor for tensor in self.r_tensor]
def __getitem__(self, idx):
return self.r_tensor[idx]
def __repr__(self):
return f"RaggedTensor(values={self.values}, row_splits={self.row_splits})"
def __len__(self):
return len(self.r_tensor)
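
    # Element-wise arithmetic operators. Each op is applied to the flat
    # `values`; a RaggedTensor operand must have identical `row_splits`
    # (checked in `__convert_to_tensor` below). A minimal usage sketch:
    #
    #   a = RaggedTensor.from_row_splits([1., 2., 3.], [0, 2, 3])
    #   b = a * 2       # values become [2., 4., 6.], same row_splits
    #   c = a + b       # OK: matching row_splits; otherwise ValueError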
def __add__(self, other):
return RaggedTensor(self.r_tensor + self.__convert_to_tensor(other),
True)
def __iadd__(self, other):
self.r_tensor += self.__convert_to_tensor(other)
return self
def __sub__(self, other):
return RaggedTensor(self.r_tensor - self.__convert_to_tensor(other),
True)
def __isub__(self, other):
self.r_tensor -= self.__convert_to_tensor(other)
return self
def __mul__(self, other):
return RaggedTensor(self.r_tensor * self.__convert_to_tensor(other),
True)
def __imul__(self, other):
self.r_tensor *= self.__convert_to_tensor(other)
return self
def __truediv__(self, other):
return RaggedTensor(self.r_tensor / self.__convert_to_tensor(other),
True)
def __itruediv__(self, other):
self.r_tensor /= self.__convert_to_tensor(other)
return self
def __floordiv__(self, other):
return RaggedTensor(self.r_tensor // self.__convert_to_tensor(other),
True)
def __ifloordiv__(self, other):
self.r_tensor //= self.__convert_to_tensor(other)
return self
    def __convert_to_tensor(self, value):
        """Converts a scalar/tensor/RaggedTensor operand to a torch.Tensor."""
        if isinstance(value, RaggedTensor):
            # RaggedTensor operands must share the same row partitioning.
            if self.row_splits.shape != value.row_splits.shape or torch.any(
                    self.row_splits != value.row_splits).item():
                raise ValueError(
                    f"Incompatible shape: {self.row_splits} and {value.row_splits}"
                )
            return value.values
        elif isinstance(value, torch.Tensor):
            return value
        elif isinstance(value, (int, float, bool)):
            # Wrap Python scalars in a single-element tensor of matching dtype.
            return torch.Tensor([value]).to(type(value))
        else:
            raise ValueError(f"Unknown type: {type(value)}")