27 if (dims.
size() != 1) {
28 std::vector<bool> seen_dims(src.
NumDims(),
false);
29 for (
const int64_t& dim : dims) {
30 seen_dims[dim] =
true;
32 if (!std::all_of(seen_dims.begin(), seen_dims.end(),
33 [](
bool seen) { return seen; })) {
35 "Arg-reduction can only have 1 or all reduction "
36 "dimensions. However, dims = {}.",
42 "Zero-size Tensor does not support Arg Reductions.");
50 if (keepdim && keepdim_shape != dst.
GetShape()) {
54 if (!keepdim && non_keepdim_shape != dst.
GetShape()) {
60 if (dims.
size() == 0) {
67 dst = dst.
Reshape(keepdim_shape);
79 #ifdef BUILD_SYCL_MODULE
85 #ifdef BUILD_CUDA_MODULE
86 ReductionCUDA(src, dst, dims, keepdim, op_code);
95 dst = dst.
Reshape(non_keepdim_shape);
std::string ToString() const
Returns string representation of device, e.g. "CPU:0", "CUDA:0".
std::string ToString() const
int64_t NumElements() const
Device GetDevice() const override
Tensor Reshape(const SizeVector &dst_shape) const
SizeVector GetShape() const
void ReductionSYCL(const Tensor &src, Tensor &dst, const SizeVector &dims, bool keepdim, ReductionOpCode op_code)
static const std::unordered_set< ReductionOpCode, utility::hash_enum_class > s_arg_reduce_ops
void ReductionCPU(const Tensor &src, Tensor &dst, const SizeVector &dims, bool keepdim, ReductionOpCode op_code)
void Reduction(const Tensor &src, Tensor &dst, const SizeVector &dims, bool keepdim, ReductionOpCode op_code)
SizeVector ReductionShape(const SizeVector &src_shape, const SizeVector &dims, bool keepdim)
Returns the shape after reduction.
Generic file read and write utility for python interface.