17 const int num_tensors = tensors.size();
18 const int64_t num_dims = tensors[0].
NumDims();
21 const Device device = tensors[0].GetDevice();
22 const Dtype dtype = tensors[0].GetDtype();
23 SizeVector combined_shape = tensors[0].GetShape();
26 for (
int i = 1; i < num_tensors; ++i) {
30 if (tensors[i].NumDims() != num_dims) {
32 "All the input tensors must have same number of "
33 "dimensions, but the tensor at index 0 has {} dimension(s) "
34 "and the tensor at index {} has {} dimension(s).",
35 num_dims, i, tensors[i].NumDims());
39 for (int64_t j = 0; j < num_dims; ++j) {
40 if (j != axis_d && combined_shape[j] != tensors[i].GetShape(j)) {
42 "All the input tensor dimensions, other than dimension "
43 "size along concatenation axis must be same, but along "
44 "dimension {}, the tensor at index 0 has size {} and "
45 "the tensor at index {} has size {}.",
46 j, combined_shape[j], i, tensors[i].GetShape(j));
50 combined_shape[axis_d] += tensors[i].GetShape(axis_d);
54 std::vector<TensorKey> common_tks;
55 for (
int i = 0; i < axis_d; ++i) {
59 Tensor combined_tensor(combined_shape, dtype, device);
62 int64_t cumulated_length = 0;
63 for (
int i = 0; i < num_tensors; ++i) {
64 const int64_t local_length = tensors[i].GetShape(axis_d);
67 std::vector<TensorKey> local_tks = common_tks;
69 cumulated_length, cumulated_length + local_length, 1));
71 cumulated_length += local_length;
73 combined_tensor.
SetItem(local_tks, tensors[i]);
76 return combined_tensor;
81 const int num_tensors = tensors.size();
83 if (num_tensors < 1) {
86 if (num_tensors == 1) {
87 std::vector<Tensor> split_tensors;
88 split_tensors.reserve(tensors[0].GetLength());
90 for (
int i = 0; i < tensors[0].GetLength(); ++i) {
91 split_tensors.push_back(tensors[0][i]);
98 std::vector<Tensor> flattened_tensors;
99 for (
int i = 0; i < num_tensors; ++i) {
101 flattened_tensors.push_back(
102 tensors[i].Reshape({tensors[i].NumElements(), 1}));
107 if (tensors[0].NumDims() == 0) {
109 "Zero-dimensional tensor can only be concatenated along "
110 "axis = null, but got {}.",
#define AssertTensorDevice(tensor,...)
#define AssertTensorDtype(tensor,...)
void push_back(const T &Elt)
static TensorKey Slice(utility::optional< int64_t > start, utility::optional< int64_t > stop, utility::optional< int64_t > step)
Tensor SetItem(const Tensor &value)
Set all items. Equivalent to tensor[:] = value in Python.
Device GetDevice() const override
Tensor Reshape(const SizeVector &dst_shape) const
SizeVector GetShape() const
constexpr bool has_value() const noexcept
constexpr T const & value() const &
void BinaryEW(const Tensor &lhs, const Tensor &rhs, Tensor &dst, BinaryEWOpCode op_code)
int64_t WrapDim(int64_t dim, int64_t max_dim, bool inclusive)
Wrap around negative dim.
SizeVector BroadcastedShape(const SizeVector &l_shape, const SizeVector &r_shape)
Returns the broadcasted shape of two shapes.
Tensor Concatenate(const std::vector< Tensor > &tensors, const utility::optional< int64_t > &axis)
Concatenates the list of tensors in their order, along the given axis into a new tensor....
Tensor Append(const Tensor &self, const Tensor &other, const utility::optional< int64_t > &axis)
Appends the two tensors, along the given axis into a new tensor. Both the tensors must have same data...
Tensor Minimum(const Tensor &input, const Tensor &other)
Computes the element-wise minimum of input and other. The tensors must have same data type and device...
Tensor Maximum(const Tensor &input, const Tensor &other)
Computes the element-wise maximum of input and other. The tensors must have same data type and device...
static Tensor ConcatenateImpl(const std::vector< Tensor > &tensors, const int64_t axis)
Generic file read and write utility for python interface.