25 template <
typename scalar_t>
26 struct ArgMinReduction {
27 using basic_reduction = sycl::minimum<scalar_t>;
28 std::pair<int64_t, scalar_t> operator()(int64_t a_idx,
31 scalar_t b_val)
const {
37 template <
typename scalar_t>
38 struct ArgMaxReduction {
39 using basic_reduction = sycl::maximum<scalar_t>;
40 std::pair<int64_t, scalar_t> operator()(int64_t a_idx,
43 scalar_t b_val)
const {
53 template <
class ReductionOp,
typename scalar_t>
54 void SYCLReductionEngine(Device device, Indexer
indexer, scalar_t identity) {
57 auto queue = device_props.
queue;
58 auto work_group_size = device_props.max_work_group_size;
59 size_t log2elements_per_group = 13;
60 auto elements_per_group = (1 << log2elements_per_group);
61 size_t log2workitems_per_group = 8;
62 auto workitems_per_group = (1 << log2workitems_per_group);
63 auto elements_per_work_item =
64 elements_per_group / workitems_per_group;
65 auto mask = ~(~0 << log2workitems_per_group);
68 for (int64_t output_idx = 0; output_idx <
indexer.NumOutputElements();
72 Indexer scalar_out_indexer =
indexer.GetPerOutputIndexer(output_idx);
73 auto num_elements = scalar_out_indexer.NumWorkloads();
74 auto num_work_groups = num_elements / elements_per_group;
75 if (num_elements > elements_per_group * num_work_groups)
78 auto num_work_items = num_work_groups * work_group_size;
80 auto red_cg = [&](
auto& cgh) {
81 auto output = scalar_out_indexer.GetOutputPtr<scalar_t>(0);
84 auto sycl_reducer = sycl::reduction(
85 output, identity, red_op,
86 {sycl::property::reduction::initialize_to_identity()});
88 sycl::nd_range<1>{num_work_items, work_group_size},
89 sycl_reducer, [=](sycl::nd_item<1> item,
auto& red_arg) {
90 auto glob_id = item.get_global_id(0);
91 auto offset = ((glob_id >> log2workitems_per_group)
92 << log2elements_per_group) +
94 auto item_out = identity;
95 for (
size_t i = 0; i < elements_per_work_item; i++) {
97 (i << log2workitems_per_group) +
offset;
98 if (idx >= num_elements)
break;
100 *scalar_out_indexer.GetInputPtr<scalar_t>(
102 item_out = red_op(item_out, val);
104 red_arg.combine(item_out);
108 auto e = queue.submit(red_cg);
110 queue.wait_and_throw();
118 template <
class ReductionOp,
typename scalar_t>
119 void SYCLArgReductionEngine(Device device, Indexer
indexer, scalar_t identity) {
122 auto queue = device_props.
queue;
123 auto work_group_size = device_props.max_work_group_size;
124 size_t log2elements_per_group = 13;
125 auto elements_per_group = (1 << log2elements_per_group);
126 size_t log2workitems_per_group = 8;
127 auto workitems_per_group = (1 << log2workitems_per_group);
128 auto elements_per_work_item =
129 elements_per_group / workitems_per_group;
130 auto mask = ~(~0 << log2workitems_per_group);
134 sycl::buffer<int32_t, 1> output_in_use{
indexer.NumOutputElements()};
135 auto e_fill = queue.submit([&](
auto& cgh) {
136 auto acc_output_in_use =
137 output_in_use.get_access<sycl::access_mode::write>(cgh);
138 cgh.fill(acc_output_in_use, 0);
141 for (int64_t output_idx = 0; output_idx <
indexer.NumOutputElements();
145 Indexer scalar_out_indexer =
indexer.GetPerOutputIndexer(output_idx);
146 auto num_elements = scalar_out_indexer.NumWorkloads();
147 auto num_work_groups = num_elements / elements_per_group;
148 if (num_elements > elements_per_group * num_work_groups)
151 auto num_work_items = num_work_groups * work_group_size;
153 sycl::buffer<int32_t, 1> this_output_in_use{output_in_use, output_idx,
155 auto arg_red_cg = [&](
auto& cgh) {
158 .get_access<sycl::access_mode::read_write>(cgh);
160 sycl::nd_range<1>{num_work_items, work_group_size},
161 [=](sycl::nd_item<1> item) {
163 *scalar_out_indexer.GetOutputPtr<int64_t>(0, 0);
165 *scalar_out_indexer.GetOutputPtr<scalar_t>(1,
167 auto glob_id = item.get_global_id(0);
168 auto this_group = item.get_group();
169 auto offset = ((glob_id >> log2workitems_per_group)
170 << log2elements_per_group) +
173 scalar_t it_val = identity;
174 for (
size_t i = 0; i < elements_per_work_item; i++) {
176 (i << log2workitems_per_group) +
offset;
177 if (idx >= num_elements)
break;
179 *scalar_out_indexer.GetInputPtr<scalar_t>(
181 std::tie(it_idx, it_val) =
182 red_op(it_idx, it_val, idx, val);
184 auto group_out_val = sycl::reduce_over_group(
185 this_group, it_val, identity,
186 typename ReductionOp::basic_reduction());
190 if (it_val == group_out_val) {
193 auto in_use = sycl::atomic_ref<
194 int32_t, sycl::memory_order::acq_rel,
195 sycl::memory_scope::device>(acc_in_use[0]);
196 while (in_use.exchange(1) == 1) {
198 std::tie(out_idx, out_val) = red_op(
199 out_idx, out_val, it_idx, group_out_val);
205 auto e = queue.submit(arg_red_cg);
207 queue.wait_and_throw();
222 case ReductionOpCode::Sum:
224 SYCLReductionEngine<sycl::plus<scalar_t>, scalar_t>(
227 case ReductionOpCode::Prod:
229 SYCLReductionEngine<sycl::multiplies<scalar_t>, scalar_t>(
232 case ReductionOpCode::Min:
233 if (indexer.NumWorkloads() == 0) {
235 "Zero-size Tensor does not support Min.");
237 identity = std::numeric_limits<scalar_t>::max();
239 SYCLReductionEngine<sycl::minimum<scalar_t>, scalar_t>(
240 device, indexer, identity);
243 case ReductionOpCode::Max:
244 if (indexer.NumWorkloads() == 0) {
246 "Zero-size Tensor does not support Max.");
248 identity = std::numeric_limits<scalar_t>::lowest();
250 SYCLReductionEngine<sycl::maximum<scalar_t>, scalar_t>(
251 device, indexer, identity);
255 utility::LogError(
"Unsupported op code.");
264 Tensor dst_acc(dst.GetShape(), src.GetDtype(), src.GetDevice());
269 case ReductionOpCode::ArgMin:
270 identity = std::numeric_limits<scalar_t>::max();
271 dst_acc.Fill(identity);
272 SYCLArgReductionEngine<ArgMinReduction<scalar_t>, scalar_t>(
273 device, indexer, identity);
275 case ReductionOpCode::ArgMax:
276 identity = std::numeric_limits<scalar_t>::lowest();
277 dst_acc.Fill(identity);
278 SYCLArgReductionEngine<ArgMaxReduction<scalar_t>, scalar_t>(
279 device, indexer, identity);
282 utility::LogError(
"Unsupported op code.");
290 "Boolean reduction only supports boolean input tensor.");
294 "Boolean reduction only supports boolean output tensor.");
298 case ReductionOpCode::All:
301 SYCLReductionEngine<sycl::logical_and<bool>,
bool>(
304 case ReductionOpCode::Any:
307 SYCLReductionEngine<sycl::logical_or<bool>,
bool>(
#define DISPATCH_DTYPE_TO_TEMPLATE(DTYPE,...)
Device GetDevice() const override
static SYCLContext & GetInstance()
Get singleton instance.
SYCLDevice GetDeviceProperties(const Device &device)
Get SYCL device properties given a CloudViewer device.
void ReductionSYCL(const Tensor &src, Tensor &dst, const SizeVector &dims, bool keepdim, ReductionOpCode op_code)
static const std::unordered_set< ReductionOpCode, utility::hash_enum_class > s_arg_reduce_ops
static const std::unordered_set< ReductionOpCode, utility::hash_enum_class > s_regular_reduce_ops
static const std::unordered_set< ReductionOpCode, utility::hash_enum_class > s_boolean_reduce_ops
CLOUDVIEWER_HOST_DEVICE Pair< First, Second > make_pair(const First &_first, const Second &_second)
Generic file read and write utility for the Python interface.
sycl::queue queue
Default queue for this device.