20 namespace shape_util {
31 const SizeVector& r_shape);
41 const SizeVector& r_shape);
51 const SizeVector& dst_shape);
61 const SizeVector& dims,
72 int64_t
WrapDim(int64_t dim, int64_t max_dim,
bool inclusive =
false);
80 SizeVector
InferShape(SizeVector shape, int64_t num_elements);
83 SizeVector
Concat(
const SizeVector& l_shape,
const SizeVector& r_shape);
86 SizeVector
Iota(int64_t n);
99 std::pair<bool, SizeVector>
Restride(
const SizeVector& old_shape,
100 const SizeVector& old_strides,
101 const SizeVector& new_shape);
SizeVector Concat(const SizeVector &l_shape, const SizeVector &r_shape)
Concatenate two shapes.
int64_t WrapDim(int64_t dim, int64_t max_dim, bool inclusive)
Wrap around negative dim.
bool CanBeBrocastedToShape(const SizeVector &src_shape, const SizeVector &dst_shape)
Returns true if src_shape can be broadcast to dst_shape.
SizeVector BroadcastedShape(const SizeVector &l_shape, const SizeVector &r_shape)
Returns the broadcasted shape of two shapes.
SizeVector ReductionShape(const SizeVector &src_shape, const SizeVector &dims, bool keepdim)
Returns the shape after reduction.
std::pair< bool, SizeVector > Restride(const SizeVector &old_shape, const SizeVector &old_strides, const SizeVector &new_shape)
SizeVector Iota(int64_t n)
Returns a SizeVector of {0, 1, ..., n - 1}, similar to std::iota.
SizeVector InferShape(SizeVector shape, int64_t num_elements)
SizeVector DefaultStrides(const SizeVector &shape)
Compute default strides for a shape when a tensor is contiguous.
bool IsCompatibleBroadcastShape(const SizeVector &l_shape, const SizeVector &r_shape)
Returns true if two shapes are compatible for broadcasting.
Generic file read and write utility for the Python interface.