17 namespace shape_util {
22 if (ndims <
static_cast<int64_t
>(shape.
size())) {
28 expanded_shape.
begin() + ndims - shape.
size());
29 return expanded_shape;
34 int64_t l_ndims = l_shape.
size();
35 int64_t r_ndims = r_shape.
size();
37 if (l_ndims == 0 || r_ndims == 0) {
45 int64_t shorter_ndims =
std::min(l_ndims, r_ndims);
46 for (int64_t i = 0; i < shorter_ndims; ++i) {
47 int64_t l_dim = l_shape[l_ndims - 1 - i];
48 int64_t r_dim = r_shape[r_ndims - 1 - i];
49 if (!(l_dim == r_dim || l_dim == 1 || r_dim == 1)) {
63 int64_t l_ndims = l_shape.
size();
64 int64_t r_ndims = r_shape.
size();
65 int64_t out_ndims =
std::max(l_ndims, r_ndims);
72 for (int64_t i = 0; i < out_ndims; i++) {
73 if (l_shape_filled[i] == 1) {
74 broadcasted_shape[i] = r_shape_filled[i];
75 }
else if (r_shape_filled[i] == 1) {
76 broadcasted_shape[i] = l_shape_filled[i];
77 }
else if (l_shape_filled[i] == r_shape_filled[i]) {
78 broadcasted_shape[i] = l_shape_filled[i];
81 "Internal error: dimension size {} is not compatible with "
82 "{}, however, this error shall have been captured by "
83 "IsCompatibleBroadcastShape already.",
84 l_shape_filled[i], r_shape_filled[i]);
87 return broadcasted_shape;
102 int64_t src_ndims = src_shape.
size();
107 for (
const int64_t& dim : dims) {
108 out_shape[
WrapDim(dim, src_ndims)] = 1;
112 std::vector<bool> dims_mask(src_ndims,
false);
113 for (
const int64_t& dim : dims) {
114 if (dims_mask[
WrapDim(dim, src_ndims)]) {
117 dims_mask[
WrapDim(dim, src_ndims)] =
true;
120 for (int64_t i = 0; i < src_ndims; ++i) {
122 out_shape[to_fill] = out_shape[i];
126 out_shape.
resize(to_fill);
131 int64_t
WrapDim(int64_t dim, int64_t max_dim,
bool inclusive) {
135 int64_t
min = -max_dim;
136 int64_t
max = inclusive ? max_dim : max_dim - 1;
138 if (dim < min || dim >
max) {
140 "Index out-of-range: dim == {}, but it must satisfy {} <= dim "
152 int64_t new_size = 1;
153 bool has_inferred_dim =
false;
154 int64_t inferred_dim = 0;
155 for (int64_t dim = 0, ndim = shape.
size(); dim != ndim; dim++) {
156 if (shape[dim] == -1) {
157 if (has_inferred_dim) {
159 "Proposed shape {}, but at most one dimension can be "
164 has_inferred_dim =
true;
165 }
else if (shape[dim] >= 0) {
166 new_size *= shape[dim];
172 if (num_elements == new_size ||
173 (has_inferred_dim && new_size > 0 && num_elements % new_size == 0)) {
174 if (has_inferred_dim) {
185 "Cannot reshape tensor of 0 elements into shape {}, "
186 "because the unspecified dimension size -1 can be any "
187 "value and is ambiguous.",
190 inferred_shape[inferred_dim] = num_elements / new_size;
192 return inferred_shape;
216 int64_t stride_size = 1;
217 for (int64_t i = shape.
size(); i > 0; --i) {
218 strides[i - 1] = stride_size;
220 stride_size *= std::max<int64_t>(shape[i - 1], 1);
228 if (old_shape.
empty()) {
237 if (numel == 0 && old_shape == new_shape) {
243 for (int64_t view_d = new_shape.
size() - 1; view_d >= 0; view_d--) {
244 if (view_d == (int64_t)(new_shape.
size() - 1)) {
245 new_strides[view_d] = 1;
247 new_strides[view_d] =
248 std::max<int64_t>(new_shape[view_d + 1], 1) *
249 new_strides[view_d + 1];
255 int64_t view_d = new_shape.
size() - 1;
257 int64_t chunk_base_stride = old_strides.
back();
259 int64_t tensor_numel = 1;
260 int64_t view_numel = 1;
261 for (int64_t tensor_d = old_shape.
size() - 1; tensor_d >= 0; tensor_d--) {
262 tensor_numel *= old_shape[tensor_d];
264 if ((tensor_d == 0) ||
265 (old_shape[tensor_d - 1] != 1 &&
266 old_strides[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
267 while (view_d >= 0 &&
268 (view_numel < tensor_numel || new_shape[view_d] == 1)) {
269 new_strides[view_d] = view_numel * chunk_base_stride;
270 view_numel *= new_shape[view_d];
273 if (view_numel != tensor_numel) {
277 chunk_base_stride = old_strides[tensor_d - 1];
int64_t NumElements() const
std::string ToString() const
iterator insert(iterator I, T &&Elt)
SizeVector Concat(const SizeVector &l_shape, const SizeVector &r_shape)
Concatenate two shapes.
int64_t WrapDim(int64_t dim, int64_t max_dim, bool inclusive)
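Wraps a negative dim around to a non-negative index. dim must lie in [-max_dim, max_dim - 1], or in [-max_dim, max_dim] when inclusive is true; otherwise an out-of-range error is reported.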
bool CanBeBrocastedToShape(const SizeVector &src_shape, const SizeVector &dst_shape)
Returns true if src_shape can be broadcast to dst_shape.
SizeVector BroadcastedShape(const SizeVector &l_shape, const SizeVector &r_shape)
Returns the broadcasted shape of two shapes.
SizeVector ReductionShape(const SizeVector &src_shape, const SizeVector &dims, bool keepdim)
Returns the shape after reduction.
std::pair< bool, SizeVector > Restride(const SizeVector &old_shape, const SizeVector &old_strides, const SizeVector &new_shape)
SizeVector Iota(int64_t n)
Returns a SizeVector of {0, 1, ..., n - 1}, similar to std::iota.
SizeVector InferShape(SizeVector shape, int64_t num_elements)
SizeVector DefaultStrides(const SizeVector &shape)
Compute default strides for a shape when a tensor is contiguous.
bool IsCompatibleBroadcastShape(const SizeVector &l_shape, const SizeVector &r_shape)
Returns true if two shapes are compatible for broadcasting.
static SizeVector ExpandFrontDims(const SizeVector &shape, int64_t ndims)
CLOUDVIEWER_HOST_DEVICE Pair< First, Second > make_pair(const First &_first, const Second &_second)
Generic file read and write utility for python interface.