Skip to content

Commit

Permalink
[oneDPL][ranges] + support sized output range for copy_if; dpcpp backend, part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeDvorskiy committed Jan 16, 2025
1 parent 3e86e51 commit 4eddf48
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 13 deletions.
1 change: 1 addition & 0 deletions include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ __pattern_copy_if(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterato
if (__first == __last)
return __result_first;

auto __n = __last - __first;
auto __keep1 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _Iterator1>();
auto __buf1 = __keep1(__first, __last);
auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _Iterator2>();
Expand Down
8 changes: 4 additions & 4 deletions include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,18 +540,18 @@ std::pair<oneapi::dpl::__internal::__difference_t<_Range1>, oneapi::dpl::__inter
__pattern_copy_if(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
_Predicate __pred, _Assign __assign)
{
using _Index = oneapi::dpl::__internal::__difference_t<_Range2>;
using _Index = std::size_t; //TODO
_Index __n = __rng1.size();
if (__n == 0 || __rng2.empty())
return {0, 0};

auto __res = oneapi::dpl::__par_backend_hetero::__parallel_copy_if_out_lim(
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1),
std::forward<_Range2>(__rng2), __pred, __assign).get();
std::forward<_Range2>(__rng2), __pred, __assign);

std::array<_Index, _2> __idx;
std::array<_Index, 2> __idx;
__res.get_values(__idx); //a blocking call
return {__idx[0], __idx[1];
return {__idx[1], __idx[0]}; //__parallel_copy_if_out_lim returns {last index in output, last index in input}
}

#if _ONEDPL_CPP20_RANGES_PRESENT
Expand Down
4 changes: 3 additions & 1 deletion include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
// Storage for the results of scan for each workgroup

using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Type>;
__result_and_scratch_storage_t __result_and_scratch{__exec, 1, __n_groups + 1};
__result_and_scratch_storage_t __result_and_scratch{__exec, 2, __n_groups + 1};

_PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);

Expand Down Expand Up @@ -1235,6 +1235,7 @@ __parallel_scan_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag
_InRng&& __in_rng, _OutRng&& __out_rng, _CreateMaskOp __create_mask_op,
_CopyByMaskOp __copy_by_mask_op)
{
using _Size = decltype(__out_rng.size());
using _ReduceOp = std::plus<_Size>;
using _Assigner = unseq_backend::__scan_assigner;
using _NoAssign = unseq_backend::__scan_no_assign;
Expand Down Expand Up @@ -1370,6 +1371,7 @@ auto
__parallel_copy_if_out_lim(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
_InRng&& __in_rng, _OutRng&& __out_rng, _Pred __pred, _Assign __assign = _Assign{})
{
using _Size = decltype(__out_rng.size());
using _ReduceOp = std::plus<_Size>;
using _CreateOp = unseq_backend::__create_mask<_Pred, _Size>;
using _CopyOp = unseq_backend::__copy_by_mask<_ReduceOp, _Assign,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <tuple>
#include <algorithm>

#include "../../iterator_impl.h"
#include "../../iterator_impl.h"

#include "sycl_defs.h"
#include "execution_sycl_defs.h"
Expand Down Expand Up @@ -683,8 +683,8 @@ struct __result_and_scratch_storage : __result_and_scratch_storage_base
}
}

template <typename _T, std::size_t _N>
void get_values(std::array<_T, _N>& __arr)
template <std::size_t _N>
void get_values(std::array<_T, _N>& __arr) const
{
assert(__result_n > 0);
assert(_N == __result_n);
Expand Down Expand Up @@ -713,14 +713,14 @@ struct __result_and_scratch_storage : __result_and_scratch_storage_base
return __get_value(idx);
}

template <typename _Event, typename _T, std::size_t _N>
template <typename _Event, std::size_t _N>
void
__wait_and_get_value(_Event&& __event, std::array<_T, _N>& __arr) const
{
if (is_USM())
__event.wait_and_throw();

return get_values(__arr);
get_values(__arr);
}
};

Expand All @@ -745,7 +745,7 @@ struct __wait_and_get_value
constexpr void
operator()(auto&& __event, const __result_and_scratch_storage<_ExecutionPolicy, _T>& __storage, std::array<_T, _N>& __arr)
{
return __storage.__wait_and_get_value(__event, __arr);
__storage.__wait_and_get_value(__event, __arr);
}

template <typename _T>
Expand Down Expand Up @@ -812,9 +812,11 @@ class __future : private std::tuple<_Args...>
}

template <typename _T, std::size_t _N>
std::enable_if_t<sizeof...(_Args) > 0>
void
get_values(std::array<_T, _N>& __arr)
{
static_assert(sizeof...(_Args) > 0);
auto& __val = std::get<0>(*this);
__wait_and_get_value{}(event(), __val, __arr);
}

Expand Down
13 changes: 12 additions & 1 deletion include/oneapi/dpl/pstl/hetero/dpcpp/unseq_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -625,12 +625,23 @@ struct __copy_by_mask
// ::std::tuple as operands, in all the other cases this is not necessary and no conversion
// is performed (i.e. __tuple_type is the same type as its operand).
if(__out_idx < __out_acc.size())
{
__assigner(static_cast<__tuple_type>(get<0>(__in_acc[__item_idx])), __out_acc[__out_idx]);
auto __last_out_idx = __wg_sums_ptr[(__n - 1) / __size_per_wg];
if(__out_idx + 1 == __last_out_idx)
{
__ret_ptr[0] = __item_idx + 1, __ret_ptr[1] = __last_out_idx;
}
}
else if(__out_idx == __out_acc.size())
{
__ret_ptr[0] = __item_idx, __ret_ptr[1] = __out_idx;
}
}
if (__item_idx == 0)
{
//copy final result to output
*__ret_ptr = __wg_sums_ptr[(__n - 1) / __size_per_wg];
__ret_ptr[1] = __wg_sums_ptr[(__n - 1) / __size_per_wg];
}
}
};
Expand Down

0 comments on commit 4eddf48

Please sign in to comment.