Source code for cloudViewer.ml.torch.classes.ragged_tensor

# ----------------------------------------------------------------------------
# -                        CloudViewer: www.cloudViewer.org                  -
# ----------------------------------------------------------------------------
# Copyright (c) 2018-2024 www.cloudViewer.org
# SPDX-License-Identifier: MIT
# ----------------------------------------------------------------------------

import torch
import numpy as np

__all__ = ['RaggedTensor']


class RaggedTensor:
    """Container for a tensor with one ragged dimension.

    A ragged tensor is a tensor whose slices along the ragged dimension may
    have different lengths. This class wraps a torch script object
    (``r_tensor``) and forwards indexing, length and arithmetic operations
    to it, so that batches whose elements have different shapes can be
    handled through a single object.
    """

    def __init__(self, r_tensor, internal=False):
        """Creates a `RaggedTensor` with specified torch script object.

        This constructor is private -- please use one of the following ops
        to build `RaggedTensor`'s:

        * `ml3d.classes.RaggedTensor.from_row_splits`

        Args:
            r_tensor: The underlying torch script object to wrap.
            internal: Must be True. Guards against accidental direct
                construction by user code.

        Raises:
            ValueError: If internal = False. This method is intended for
                internal use.
        """
        if not internal:
            raise ValueError(
                "RaggedTensor constructor is private, please use one of the factory method instead(e.g. RaggedTensor.from_row_splits())"
            )
        self.r_tensor = r_tensor

    @staticmethod
    def _as_tensor(data, list_dtype, copy):
        """Converts a list / numpy array / tensor input to a torch.Tensor.

        Args:
            data: A python list, `np.ndarray` or `torch.Tensor`. Any other
                type is returned unchanged.
            list_dtype: dtype used when `data` is a python list.
            copy: If True, the result does not share memory with `data`.

        Returns:
            The converted (and optionally copied) object.
        """
        if isinstance(data, list):
            # torch.tensor always allocates new storage for list input.
            return torch.tensor(data, dtype=list_dtype)
        if isinstance(data, np.ndarray):
            converted = torch.from_numpy(data)
            # torch.from_numpy shares memory with the array, so an explicit
            # clone is required to honour `copy=True`.
            return converted.clone() if copy else converted
        if isinstance(data, torch.Tensor) and copy:
            return data.clone()
        return data

    @classmethod
    def from_row_splits(cls, values, row_splits, validate=True, copy=True):
        """Creates a RaggedTensor with rows partitioned by row_splits.

        The returned `RaggedTensor` corresponds with the python list
        defined by::

            result = [values[row_splits[i]:row_splits[i + 1]]
                      for i in range(len(row_splits) - 1)]

        Args:
            values: A Tensor with shape [N, None]. A python list is
                converted to a float64 tensor; a numpy array keeps its
                dtype.
            row_splits: A 1-D integer tensor with shape `[N+1]`. Must not
                be empty, and must be stored in ascending order.
                `row_splits[0]` must be zero and `row_splits[-1]` must be
                `N`. A python list is converted to an int64 tensor.
            validate: Verify that `row_splits` are compatible with
                `values`. Set it to False to avoid expensive checks.
            copy: Whether to do a deep copy for `values` and `row_splits`.
                Set it to False to save memory for short term usage.

        Returns:
            A `RaggedTensor` container.

        Example:
            The following call builds a ragged tensor representing the
            rows ``[[3, 1, 4, 1], [], [5, 9, 2], [6], []]``::

                ml3d.classes.RaggedTensor.from_row_splits(
                    values=[3, 1, 4, 1, 5, 9, 2, 6],
                    row_splits=[0, 4, 4, 7, 8, 8])
        """
        values = cls._as_tensor(values, torch.float64, copy)
        row_splits = cls._as_tensor(row_splits, torch.int64, copy)

        r_tensor = torch.classes.my_classes.RaggedTensor().from_row_splits(
            values, row_splits, validate)
        return cls(r_tensor, internal=True)

    @property
    def values(self):
        """The concatenated rows for this ragged tensor."""
        return self.r_tensor.get_values()

    @property
    def row_splits(self):
        """The row-split indices for this ragged tensor's `values`."""
        return self.r_tensor.get_row_splits()

    @property
    def dtype(self):
        """The `DType` of values in this ragged tensor."""
        return self.values.dtype

    @property
    def device(self):
        """The device of values in this ragged tensor."""
        return self.values.device

    @property
    def shape(self):
        """The statically known shape of this ragged tensor.

        The ragged dimension is reported as `None`.
        """
        return [len(self.r_tensor), None, *self.values.shape[1:]]

    @property
    def requires_grad(self):
        """Read/writable `requires_grad` for values."""
        return self.values.requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        self.values.requires_grad = value

    def clone(self):
        """Returns a clone of object."""
        return RaggedTensor(self.r_tensor.clone(), True)

    def to_list(self):
        """Returns a list of tensors obtained by iterating the rows."""
        return list(self.r_tensor)

    def __getitem__(self, idx):
        return self.r_tensor[idx]

    def __repr__(self):
        return f"RaggedTensor(values={self.values}, row_splits={self.row_splits})"

    def __len__(self):
        return len(self.r_tensor)

    # Arithmetic operators: the right-hand operand is first normalised to a
    # torch.Tensor, then the operation is delegated to the wrapped object.
    # Binary forms return a new RaggedTensor; in-place forms return self.

    def __add__(self, other):
        return RaggedTensor(self.r_tensor + self.__convert_to_tensor(other),
                            True)

    def __iadd__(self, other):
        self.r_tensor += self.__convert_to_tensor(other)
        return self

    def __sub__(self, other):
        return RaggedTensor(self.r_tensor - self.__convert_to_tensor(other),
                            True)

    def __isub__(self, other):
        self.r_tensor -= self.__convert_to_tensor(other)
        return self

    def __mul__(self, other):
        return RaggedTensor(self.r_tensor * self.__convert_to_tensor(other),
                            True)

    def __imul__(self, other):
        self.r_tensor *= self.__convert_to_tensor(other)
        return self

    def __truediv__(self, other):
        return RaggedTensor(self.r_tensor / self.__convert_to_tensor(other),
                            True)

    def __itruediv__(self, other):
        self.r_tensor /= self.__convert_to_tensor(other)
        return self

    def __floordiv__(self, other):
        return RaggedTensor(self.r_tensor // self.__convert_to_tensor(other),
                            True)

    def __ifloordiv__(self, other):
        self.r_tensor //= self.__convert_to_tensor(other)
        return self

    def __convert_to_tensor(self, value):
        """Converts scalar/tensor/RaggedTensor to torch.Tensor.

        Args:
            value: A `RaggedTensor` (must have the same `row_splits` as
                this object), a `torch.Tensor`, or a python scalar
                (bool, int or float).

        Returns:
            A `torch.Tensor`. Scalars become one-element tensors of dtype
            bool, int64 or float64 respectively.

        Raises:
            ValueError: If `value` is a `RaggedTensor` with different
                `row_splits`, or is of an unsupported type.
        """
        if isinstance(value, RaggedTensor):
            if self.row_splits.shape != value.row_splits.shape or torch.any(
                    self.row_splits != value.row_splits).item():
                raise ValueError(
                    f"Incompatible shape : {self.row_splits} and {value.row_splits}"
                )
            return value.values
        if isinstance(value, torch.Tensor):
            return value
        # Build the scalar tensor directly in its target dtype. Going
        # through torch.Tensor([...]) would round the value to float32
        # first. bool is tested before int because bool subclasses int.
        if isinstance(value, bool):
            return torch.tensor([value], dtype=torch.bool)
        if isinstance(value, int):
            return torch.tensor([value], dtype=torch.int64)
        if isinstance(value, float):
            return torch.tensor([value], dtype=torch.float64)
        raise ValueError(f"Unknown type : {type(value)}")