Skip to content

Commit

Permalink
Add submdspan_mapping for layout_right_padded
Browse files Browse the repository at this point in the history
  • Loading branch information
crtrott committed Jun 17, 2024
1 parent 2841f0c commit d18d407
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 16 deletions.
106 changes: 92 additions & 14 deletions include/experimental/__p2630_bits/submdspan_mapping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ layout_left::mapping<Extents>::submdspan_mapping_impl(
#endif
}

// Actual submdspan mapping call
template <size_t PaddingValue>
template <class Extents>
template <class... SliceSpecifiers>
Expand Down Expand Up @@ -403,6 +402,21 @@ struct deduce_layout_right_submapping<
}
};

// We are reusing the same thing for layout_right and layout_right_padded
// For layout_right as source StaticStride is static_extent(Rank-1)
template<class Extents, size_t NumGaps, size_t StaticStride>
struct Compute_S_static_layout_right {
// Neither StaticStride nor any of the looked for extents can zero.
// StaticStride never can be zero, the static_extents we are looking at are associated with
// integral slice specifiers - which wouldn't be valid for zero extent
template<size_t ... Idx>
MDSPAN_INLINE_FUNCTION
static constexpr size_t value(std::index_sequence<Idx...>) {
size_t val = ((Idx >= Extents::rank() - 1 - NumGaps && Idx < Extents::rank() - 1 ? (Extents::static_extent(Idx) == dynamic_extent?0:Extents::static_extent(Idx)) : 1) * ... * (StaticStride == dynamic_extent?0:StaticStride));
return val == 0?dynamic_extent:val;
}
};

} // namespace detail

// Actual submdspan mapping call
Expand All @@ -423,14 +437,6 @@ layout_right::mapping<Extents>::submdspan_mapping_impl(
std::make_index_sequence<src_ext_t::rank()>,
SliceSpecifiers...>;

using dst_layout_t = std::conditional_t<
deduce_layout::layout_right_value(), layout_right,
std::conditional_t<
deduce_layout::layout_right_padded_value(),
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<dynamic_extent>,
layout_stride>>;
using dst_mapping_t = typename dst_layout_t::template mapping<dst_ext_t>;

// Figure out if any slice's lower bound equals the corresponding extent.
// If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060.
const bool out_of_bounds =
Expand All @@ -439,20 +445,21 @@ layout_right::mapping<Extents>::submdspan_mapping_impl(
out_of_bounds ? this->required_span_size()
: this->operator()(detail::first_of(slices)...));

if constexpr (std::is_same_v<dst_layout_t, layout_right>) {
if constexpr (deduce_layout::layout_right_value()) {
// layout_right case
using dst_mapping_t = typename layout_right::mapping<dst_ext_t>;
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t(dst_ext),
offset};
} else if constexpr (std::is_same_v<
dst_layout_t,
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<
dynamic_extent>>) {
} else if constexpr (deduce_layout::layout_right_padded_value()) {
constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::Compute_S_static_layout_left<Extents, deduce_layout::gap_len, Extents::static_extent(Extents::rank() - 1)>::value(std::make_index_sequence<Extents::rank()>());
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<S_static>::template mapping<dst_ext_t>;
return submdspan_mapping_result<dst_mapping_t>{
dst_mapping_t(dst_ext,
stride(src_ext_t::rank() - 2 - deduce_layout::gap_len)),
offset};
} else {
// layout_stride case
using dst_mapping_t = typename layout_stride::mapping<dst_ext_t>;
auto inv_map = detail::inv_map_rank(std::integral_constant<size_t, 0>(),
std::index_sequence<>(), slices...);
return submdspan_mapping_result<dst_mapping_t> {
Expand All @@ -477,6 +484,77 @@ layout_right::mapping<Extents>::submdspan_mapping_impl(
#endif
}

template <size_t PaddingValue>
template <class Extents>
template <class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION constexpr auto
MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<PaddingValue>::mapping<Extents>::submdspan_mapping_impl(
SliceSpecifiers... slices) const {

// compute sub extents
using src_ext_t = Extents;
auto dst_ext = submdspan_extents(extents(), slices...);
using dst_ext_t = decltype(dst_ext);

if constexpr (Extents::rank() == 0) { // rank-0 case
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<PaddingValue>::template mapping<Extents>;
return submdspan_mapping_result<dst_mapping_t>{*this, 0};
} else {
// Figure out if any slice's lower bound equals the corresponding extent.
// If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060.
// figure out sub layout type
const bool out_of_bounds =
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::any_slice_out_of_bounds(this->extents(), slices...);
auto offset = static_cast<size_t>(
out_of_bounds ? this->required_span_size()
: this->operator()(MDSPAN_IMPL_STANDARD_NAMESPACE::detail::first_of(slices)...));
if constexpr (dst_ext_t::rank() == 0) { // result rank-0
using dst_mapping_t = typename layout_right::template mapping<dst_ext_t>;
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t{dst_ext}, offset};
} else { // general case
using deduce_layout = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::deduce_layout_right_submapping<
typename dst_ext_t::index_type, dst_ext_t::rank(),
decltype(std::make_index_sequence<src_ext_t::rank()>()),
SliceSpecifiers...>;

if constexpr (deduce_layout::layout_right_value() && dst_ext_t::rank() == 1) { // getting rank-1 from rightmost
using dst_mapping_t = typename layout_right::template mapping<dst_ext_t>;
return submdspan_mapping_result<dst_mapping_t>{dst_mapping_t{dst_ext}, offset};
} else if constexpr (deduce_layout::layout_right_padded_value()) { // can keep layout_right_padded
constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::Compute_S_static_layout_right<Extents, deduce_layout::gap_len, static_padding_stride>::value(std::make_index_sequence<Extents::rank()>());
using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_right_padded<S_static>::template mapping<dst_ext_t>;
return submdspan_mapping_result<dst_mapping_t>{
dst_mapping_t(dst_ext, stride(Extents::rank() - 2 - deduce_layout::gap_len)), offset};
} else { // layout_stride
auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank(std::integral_constant<size_t, 0>(),
std::index_sequence<>(), slices...);
using dst_mapping_t = typename layout_stride::template mapping<dst_ext_t>;
return submdspan_mapping_result<dst_mapping_t> {
dst_mapping_t(dst_ext,
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::construct_sub_strides(
*this, inv_map,
// HIP needs deduction guides to have markups so we need to be explicit
// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have
// the issue But Clang-CUDA also doesn't accept the use of deduction guide so
// disable it for CUDA alltogether
#if defined(_MDSPAN_HAS_HIP) || defined(_MDSPAN_HAS_CUDA)
std::tuple<decltype(MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices))...>{
MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})),
#else
std::tuple{MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})),
#endif
offset
};
}
}
}


#if defined(__NVCC__) && !defined(__CUDA_ARCH__) && defined(__GNUC__)
__builtin_unreachable();
#endif
}

//**********************************
// layout_stride submdspan_mapping
//*********************************
Expand Down
13 changes: 13 additions & 0 deletions include/experimental/__p2642_bits/layout_padded.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,19 @@ class layout_right_padded<PaddingValue>::mapping {
return !(left == right);
}
#endif

// [mdspan.submdspan.mapping], submdspan mapping specialization
template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
constexpr auto submdspan_mapping_impl(
SliceSpecifiers... slices) const;

template<class... SliceSpecifiers>
MDSPAN_INLINE_FUNCTION
friend constexpr auto submdspan_mapping(
const mapping& src, SliceSpecifiers... slices) {
return src.submdspan_mapping_impl(slices...);
}
};
}
}
29 changes: 29 additions & 0 deletions tests/test_submdspan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ using submdspan_test_types =
, std::tuple<Kokkos::layout_left, Kokkos::layout_stride, Kokkos::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, Kokkos::extents<size_t,4,dyn,7>, int, Kokkos::full_extent_t, std::pair<int,int>, int, Kokkos::full_extent_t, int>
// layout_right to layout_right_padded
, std::tuple<Kokkos::layout_right, layout_right_padded<dyn>, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, std::pair<int,int>>
, std::tuple<Kokkos::layout_right, layout_right_padded<20>, Kokkos::extents<size_t,dyn,20>, args_t<10,20>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, std::pair<int,int>>
, std::tuple<Kokkos::layout_right, layout_right_padded<dyn>, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, int, std::pair<int,int>>
, std::tuple<Kokkos::layout_right, layout_right_padded<dyn>, Kokkos::dextents<size_t,4>, args_t<10,20,30,40>, Kokkos::dextents<size_t,3>, std::pair<int,int>, Kokkos::full_extent_t, int, std::pair<int,int>>
, std::tuple<Kokkos::layout_right, layout_right_padded<dyn>, Kokkos::dextents<size_t,5>, args_t<10,20,30,40,50>, Kokkos::dextents<size_t,3>, int, std::pair<int,int>, Kokkos::full_extent_t, int, std::pair<int,int>>
Expand Down Expand Up @@ -206,6 +207,34 @@ using submdspan_test_types =
, std::tuple<layout_left_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,3>, Kokkos::full_extent_t, Kokkos::strided_slice<int,int,int>, Kokkos::full_extent_t>
, std::tuple<layout_left_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, int, Kokkos::strided_slice<int,int,int>>
, std::tuple<layout_left_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,4>, args_t<10,20,30,40>, Kokkos::dextents<size_t,3>, Kokkos::full_extent_t, Kokkos::full_extent_t, int, Kokkos::full_extent_t>
// layout_right_padded to layout_right
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_right, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,0>, int, int>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_right, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,1>, int, std::pair<int,int>>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_right, Kokkos::extents<size_t,dyn,30>, args_t<10,20>, Kokkos::extents<size_t,30>, int, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<4>, Kokkos::layout_right, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,1>, int, int, std::pair<int,int>>
, std::tuple<layout_right_padded<4>, Kokkos::layout_right, Kokkos::extents<size_t,dyn,dyn,30>, args_t<10,20,30>, Kokkos::extents<size_t,30>, int, int, Kokkos::full_extent_t>
// layout_right_padded to layout_right_padded
, std::tuple<layout_right_padded<dyn>, layout_right_padded<dyn>, Kokkos::dextents<size_t,0>, args_t<>, Kokkos::dextents<size_t,0>>
, std::tuple<layout_right_padded<4>, layout_right_padded<4>, Kokkos::dextents<size_t,0>, args_t<>, Kokkos::dextents<size_t,0>>
, std::tuple<layout_right_padded<dyn>, layout_right_padded<dyn>, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<4>, layout_right_padded<dyn>, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,2>, std::pair<int, int>, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<4>, layout_right_padded<dyn>, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, std::pair<int, int>>
, std::tuple<layout_right_padded<dyn>, layout_right_padded<dyn>, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,2>, std::pair<int, int>, std::pair<int, int>>
, std::tuple<layout_right_padded<22>, layout_right_padded<22>, Kokkos::extents<size_t,10,20>, args_t<10,20>, Kokkos::extents<size_t, 10, dyn>, Kokkos::full_extent_t, std::pair<int, int>>
, std::tuple<layout_right_padded<dyn>, layout_right_padded<dyn>, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, int, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<4>, layout_right_padded<dyn>, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,2>, std::pair<int, int>, int, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<4>, layout_right_padded<dyn>, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, int, std::pair<int, int>>
, std::tuple<layout_right_padded<dyn>, layout_right_padded<dyn>, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,2>, std::pair<int, int>, int, std::pair<int, int>>
, std::tuple<layout_right_padded<32>, layout_right_padded<640>, Kokkos::extents<size_t,dyn,20,30>, args_t<10,20,30>, Kokkos::extents<size_t, dyn, dyn>, Kokkos::full_extent_t, int, std::pair<int, int>>
, std::tuple<layout_right_padded<dyn>, layout_right_padded<dyn>, Kokkos::dextents<size_t,4>, args_t<10,20,30,40>, Kokkos::dextents<size_t,3>, Kokkos::full_extent_t, Kokkos::full_extent_t, int, Kokkos::full_extent_t>
// layout_right_padded to layout_stride
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,1>, args_t<10>, Kokkos::dextents<size_t,1>, Kokkos::strided_slice<int,int,int>>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,1>, int, Kokkos::strided_slice<int,int,int>>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,2>, Kokkos::strided_slice<int,int,int>, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,2>, args_t<10,20>, Kokkos::dextents<size_t,2>, Kokkos::full_extent_t, Kokkos::strided_slice<int,int,int>>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,3>, Kokkos::full_extent_t, Kokkos::strided_slice<int,int,int>, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,3>, args_t<10,20,30>, Kokkos::dextents<size_t,2>, Kokkos::strided_slice<int,int,int>, int, Kokkos::full_extent_t>
, std::tuple<layout_right_padded<dyn>, Kokkos::layout_stride, Kokkos::dextents<size_t,4>, args_t<10,20,30,40>, Kokkos::dextents<size_t,3>, Kokkos::full_extent_t, int, Kokkos::full_extent_t, Kokkos::full_extent_t>
// Testing of customization point design
, std::tuple<Foo::layout_foo, Foo::layout_foo, Kokkos::dextents<size_t,1>, args_t<10>, Kokkos::dextents<size_t,1>, Kokkos::full_extent_t>
, std::tuple<Foo::layout_foo, Foo::layout_foo, Kokkos::dextents<size_t,1>, args_t<10>, Kokkos::dextents<size_t,1>, std::pair<int,int>>
Expand Down
4 changes: 2 additions & 2 deletions tests/test_submdspan_static_slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ TEST(TestMdspan, SubmdspanStaticSlice_Right_i345_TupleFullTuple) {

{
using expected_extents_type = Kokkos::extents<int, Kokkos::dynamic_extent, 4, Kokkos::dynamic_extent>;
using expected_layout_type = Kokkos::Experimental::layout_right_padded<Kokkos::dynamic_extent>;
using expected_layout_type = Kokkos::Experimental::layout_right_padded<5>;
using expected_output_mdspan_type = Kokkos::mdspan<float, expected_extents_type, expected_layout_type>;

auto runTest = [&] (auto sliceSpec0, auto sliceSpec1) {
Expand All @@ -522,7 +522,7 @@ TEST(TestMdspan, SubmdspanStaticSlice_Right_i345_TupleFullTuple) {
}
{
using expected_extents_type = Kokkos::extents<int, 2, 4, 3>;
using expected_layout_type = Kokkos::Experimental::layout_right_padded<Kokkos::dynamic_extent>;
using expected_layout_type = Kokkos::Experimental::layout_right_padded<5>;
using expected_output_mdspan_type = Kokkos::mdspan<float, expected_extents_type, expected_layout_type>;

auto runTest = [&] (auto sliceSpec0, auto sliceSpec1) {
Expand Down

0 comments on commit d18d407

Please sign in to comment.