// NOTE(review): this listing is fragmentary. The function header, the
// definitions of `indexer` and `index_ptr`, and the closing braces are not
// visible here; the comments below describe only the visible statements.
//
// Function template parameterized on the callable type `func_t`. The callable
// is received as `element_kernel` and is invoked with a (src_ptr, dst_ptr)
// pair for each workload index handled by `element_func`.
18 template <
typename func_t>
24 const func_t& element_kernel) {
// Product of src.GetShape(d) over d = 1 .. src.NumDims() - 1; dimension 0 is
// deliberately left out of the product.
35 int64_t broadcasting_elems = 1;
36 for (int64_t d = 1; d < src.
NumDims(); ++d) {
37 broadcasting_elems *= src.
GetShape(d);
// Per-workload functor; everything it uses is captured by value ([=]).
39 auto element_func = [=](int64_t workload_idx) {
// Split the flat workload index into a quotient and a remainder with respect
// to broadcasting_elems.
// NOTE(review): both results are stored in `int` although workload_idx and
// broadcasting_elems are int64_t, so values outside the range of int would be
// narrowed. Consider int64_t here.
40 int reduction_idx = workload_idx / broadcasting_elems;
41 int broadcasting_idx = workload_idx % broadcasting_elems;
// Look up the target index for this quotient and rebuild a flat destination
// offset using the same remainder.
// NOTE(review): no range check on `idx` is visible in this fragment; confirm
// whether the caller validates the index values.
43 const int64_t idx = index_ptr[reduction_idx];
44 int64_t dst_idx = idx * broadcasting_elems + broadcasting_idx;
// NOTE(review): the source pointer is taken from the indexer's *output* side
// and the destination pointer from its *input* side. Presumably this mirrors
// how `indexer` was constructed (not visible here); confirm before changing.
46 void* src_ptr =
indexer.GetOutputPtr(0, workload_idx);
47 void* dst_ptr =
indexer.GetInputPtr(0, dst_idx);
49 element_kernel(src_ptr, dst_ptr);
// Plain sequential loop over indexer.NumWorkloads() indices; the loop body is
// not visible in this listing (presumably it calls element_func(d); verify).
54 for (int64_t d = 0; d <
indexer.NumWorkloads(); ++d) {
// Element-wise accumulation kernel, templated on the element type `scalar_t`.
// The visible statements cast the untyped `dst` and `src` pointers to
// scalar_t pointers and add the single source element into the destination
// element in place.
// NOTE(review): the function header and closing brace are not visible in this
// listing.
59 template <
typename scalar_t>
61 scalar_t* dst_s_ptr =
static_cast<scalar_t*
>(dst);
62 const scalar_t* src_s_ptr =
static_cast<const scalar_t*
>(src);
// Non-atomic read-modify-write of one element.
63 *dst_s_ptr += *src_s_ptr;
// Launches the index-reduction kernel on src's device, passing a capture-less
// lambda that forwards each (src, dst) element pair to CPUSumKernel<scalar_t>.
// NOTE(review): `scalar_t` is not defined in the visible lines; presumably it
// is supplied by an enclosing dtype-dispatch macro. Confirm against the full
// file.
71 LaunchIndexReductionKernel(dim, src.GetDevice(), index, src, dst,
72 [](const void* src, void* dst) {
73 CPUSumKernel<scalar_t>(src, dst);
#define CLOUDVIEWER_HOST_DEVICE
#define DISPATCH_FLOAT_DTYPE_TO_TEMPLATE(DTYPE,...)
SizeVector GetShape() const
static CLOUDVIEWER_HOST_DEVICE void CPUSumKernel(const void *src, void *dst)
void LaunchIndexReductionKernel(int64_t dim, const Device &device, const Tensor &index, const Tensor &src, Tensor &dst, const func_t &element_kernel)
void IndexAddCPU_(int64_t dim, const Tensor &index, const Tensor &src, Tensor &dst)
Generic file read and write utility for python interface.