diff --git a/include/experimental/__p2630_bits/submdspan_mapping.hpp b/include/experimental/__p2630_bits/submdspan_mapping.hpp index cf1bdd1e..bad67328 100644 --- a/include/experimental/__p2630_bits/submdspan_mapping.hpp +++ b/include/experimental/__p2630_bits/submdspan_mapping.hpp @@ -182,6 +182,21 @@ struct deduce_layout_left_submapping< } }; +// We are reusing the same thing for layout_left and layout_left_padded +// For layout_left as source StaticStride is static_extent(0) +template +struct Compute_S_static_layout_left { + // Neither StaticStride nor any of the looked for extents can zero. + // StaticStride never can be zero, the static_extents we are looking at are associated with + // integral slice specifiers - which wouldn't be valid for zero extent + template + MDSPAN_INLINE_FUNCTION + static constexpr size_t value(std::index_sequence) { + size_t val = ((Idx>0 && Idx<=NumGaps ? (Extents::static_extent(Idx) == dynamic_extent?0:Extents::static_extent(Idx)) : 1) * ... * (StaticStride == dynamic_extent?0:StaticStride)); + return val == 0?dynamic_extent:val; + } +}; + } // namespace detail // Actual submdspan mapping call @@ -202,14 +217,6 @@ layout_left::mapping::submdspan_mapping_impl( std::make_index_sequence, SliceSpecifiers...>; - using dst_layout_t = std::conditional_t< - deduce_layout::layout_left_value(), layout_left, - std::conditional_t< - deduce_layout::layout_left_padded_value(), - MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded, - layout_stride>>; - using dst_mapping_t = typename dst_layout_t::template mapping; - // Figure out if any slice's lower bound equals the corresponding extent. // If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060. const bool out_of_bounds = @@ -218,17 +225,19 @@ layout_left::mapping::submdspan_mapping_impl( out_of_bounds ? this->required_span_size() : this->operator()(detail::first_of(slices)...)); - if constexpr (std::is_same_v) { + if constexpr (deduce_layout::layout_left_value()) { // layout_left case + using dst_mapping_t = typename layout_left::template mapping; return submdspan_mapping_result{dst_mapping_t(dst_ext), offset}; - } else if constexpr (std::is_same_v>) { + } else if constexpr (deduce_layout::layout_left_padded_value()) { + constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::Compute_S_static_layout_left::value(std::make_index_sequence()); + using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded::template mapping; return submdspan_mapping_result{ dst_mapping_t(dst_ext, stride(1 + deduce_layout::gap_len)), offset}; } else { // layout_stride case + using dst_mapping_t = typename layout_stride::mapping; auto inv_map = detail::inv_map_rank(std::integral_constant(), std::index_sequence<>(), slices...); return submdspan_mapping_result { @@ -253,6 +262,78 @@ layout_left::mapping::submdspan_mapping_impl( #endif } +// Actual submdspan mapping call +template +template +template +MDSPAN_INLINE_FUNCTION constexpr auto +MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded::mapping::submdspan_mapping_impl( + SliceSpecifiers... slices) const { + + // compute sub extents + using src_ext_t = Extents; + auto dst_ext = submdspan_extents(extents(), slices...); + using dst_ext_t = decltype(dst_ext); + + if constexpr (Extents::rank() == 0) { // rank-0 case + using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded::template mapping; + return submdspan_mapping_result{*this, 0}; + } else { + const bool out_of_bounds = + MDSPAN_IMPL_STANDARD_NAMESPACE::detail::any_slice_out_of_bounds(this->extents(), slices...); + auto offset = static_cast( + out_of_bounds ? this->required_span_size() + : this->operator()(MDSPAN_IMPL_STANDARD_NAMESPACE::detail::first_of(slices)...)); + if constexpr (dst_ext_t::rank() == 0) { // result rank-0 + using dst_mapping_t = typename layout_left::template mapping; + return submdspan_mapping_result{dst_mapping_t{dst_ext}, offset}; + } else { // general case + // Figure out if any slice's lower bound equals the corresponding extent. + // If so, bypass evaluating the layout mapping. This fixes LWG Issue 4060. + // figure out sub layout type + using deduce_layout = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::deduce_layout_left_submapping< + typename dst_ext_t::index_type, dst_ext_t::rank(), + decltype(std::make_index_sequence()), + SliceSpecifiers...>; + + if constexpr (deduce_layout::layout_left_value() && dst_ext_t::rank() == 1) { // getting rank-1 from leftmost + using dst_mapping_t = typename layout_left::template mapping; + return submdspan_mapping_result{dst_mapping_t{dst_ext}, offset}; + } else if constexpr (deduce_layout::layout_left_padded_value()) { // can keep layout_left_padded + constexpr size_t S_static = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::Compute_S_static_layout_left::value(std::make_index_sequence()); + using dst_mapping_t = typename MDSPAN_IMPL_PROPOSED_NAMESPACE::layout_left_padded::template mapping; + return submdspan_mapping_result{ + dst_mapping_t(dst_ext, stride(1 + deduce_layout::gap_len)), offset}; + } else { // layout_stride + auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank(std::integral_constant(), + std::index_sequence<>(), slices...); + using dst_mapping_t = typename layout_stride::template mapping; + return submdspan_mapping_result { + dst_mapping_t(dst_ext, + MDSPAN_IMPL_STANDARD_NAMESPACE::detail::construct_sub_strides( + *this, inv_map, +// HIP needs deduction guides to have markups so we need to be explicit +// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have +// the issue But Clang-CUDA also doesn't accept the use of deduction guide so +// disable it for CUDA alltogether +#if defined(_MDSPAN_HAS_HIP) || defined(_MDSPAN_HAS_CUDA) + std::tuple{ + MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})), +#else + std::tuple{MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...})), +#endif + offset + }; + } + } + } + + +#if defined(__NVCC__) && !defined(__CUDA_ARCH__) && defined(__GNUC__) + __builtin_unreachable(); +#endif +} + //********************************** // layout_right submdspan_mapping //********************************* diff --git a/include/experimental/__p2642_bits/layout_padded.hpp b/include/experimental/__p2642_bits/layout_padded.hpp index cd4bbcf5..372966cc 100644 --- a/include/experimental/__p2642_bits/layout_padded.hpp +++ b/include/experimental/__p2642_bits/layout_padded.hpp @@ -497,10 +497,12 @@ class layout_left_padded::mapping { // [mdspan.submdspan.mapping], submdspan mapping specialization template + MDSPAN_INLINE_FUNCTION constexpr auto submdspan_mapping_impl( SliceSpecifiers... slices) const; template + MDSPAN_INLINE_FUNCTION friend constexpr auto submdspan_mapping( const mapping& src, SliceSpecifiers... slices) { return src.submdspan_mapping_impl(slices...); diff --git a/tests/test_submdspan.cpp b/tests/test_submdspan.cpp index 28d0732d..2d763664 100644 --- a/tests/test_submdspan.cpp +++ b/tests/test_submdspan.cpp @@ -102,9 +102,14 @@ TEST(TestSubmdspanLayoutRightStaticSizedTuples, test_submdspan_layout_right_stat template using args_t = std::index_sequence; +template +using layout_left_padded = Kokkos::Experimental::layout_left_padded; +template +using layout_right_padded = Kokkos::Experimental::layout_right_padded; + using submdspan_test_types = ::testing::Types< - // LayoutLeft to LayoutLeft + // layout_left to layout_left std::tuple, args_t<10>, Kokkos::dextents, Kokkos::full_extent_t> , std::tuple, args_t<10>, Kokkos::dextents, std::pair> , std::tuple, args_t<10>, Kokkos::dextents, int> @@ -119,7 +124,7 @@ using submdspan_test_types = , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::dextents, Kokkos::full_extent_t, std::pair, int, int, int, int> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::dextents, Kokkos::full_extent_t, int, int, int ,int, int> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::dextents, std::pair, int, int, int, int, int> - // LayoutRight to LayoutRight + // layout_right to layout_right , std::tuple, args_t<10>, Kokkos::dextents, Kokkos::full_extent_t> , std::tuple, args_t<10>, Kokkos::dextents, std::pair> , std::tuple, args_t<10>, Kokkos::dextents, int> @@ -132,7 +137,7 @@ using submdspan_test_types = , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::dextents, int, int, int, std::pair, Kokkos::full_extent_t, Kokkos::full_extent_t> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::dextents, int, int, int, int, std::pair, Kokkos::full_extent_t> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::dextents, int, int, int, int, int, Kokkos::full_extent_t> - // LayoutRight to LayoutRight Check Extents Preservation + // layout_right to layout_right Check Extents Preservation , std::tuple, args_t<10>, Kokkos::extents, Kokkos::full_extent_t> , std::tuple, args_t<10>, Kokkos::extents, std::pair> , std::tuple, args_t<10>, Kokkos::extents, int> @@ -145,12 +150,13 @@ using submdspan_test_types = , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::extents, int, int, int, std::pair, Kokkos::full_extent_t, Kokkos::full_extent_t> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::extents, int, int, int, int, std::pair, Kokkos::full_extent_t> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::extents, int, int, int, int, int, Kokkos::full_extent_t> - // LayoutLeft to layout_left_padded - , std::tuple, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, std::pair, Kokkos::full_extent_t> - , std::tuple, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, std::pair, int, Kokkos::full_extent_t> - , std::tuple, Kokkos::dextents, args_t<10,20,30,40>, Kokkos::dextents, std::pair, int, Kokkos::full_extent_t, std::pair> - , std::tuple, Kokkos::dextents, args_t<10,20,30,40,50>, Kokkos::dextents, std::pair, int, Kokkos::full_extent_t, std::pair, int> - // LayoutLeft to LayoutStride + // layout_left to layout_left_padded + , std::tuple, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, std::pair, Kokkos::full_extent_t> + , std::tuple, Kokkos::extents, args_t<10,20>, Kokkos::dextents, std::pair, Kokkos::full_extent_t> + , std::tuple, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, std::pair, int, Kokkos::full_extent_t> + , std::tuple, Kokkos::dextents, args_t<10,20,30,40>, Kokkos::dextents, std::pair, int, Kokkos::full_extent_t, std::pair> + , std::tuple, Kokkos::dextents, args_t<10,20,30,40,50>, Kokkos::dextents, std::pair, int, Kokkos::full_extent_t, std::pair, int> + // layout_left to layout_stride , std::tuple, args_t<10>, Kokkos::dextents, Kokkos::strided_slice> , std::tuple, args_t<10,20>, Kokkos::dextents, Kokkos::strided_slice, int> , std::tuple, args_t<10,20>, Kokkos::dextents, std::pair, Kokkos::strided_slice> @@ -159,10 +165,10 @@ using submdspan_test_types = , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::extents, Kokkos::full_extent_t, int, std::pair, int, int, Kokkos::full_extent_t> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::extents, int, Kokkos::full_extent_t, std::pair, int, Kokkos::full_extent_t, int> // layout_right to layout_right_padded - , std::tuple, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, Kokkos::full_extent_t, std::pair> - , std::tuple, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, Kokkos::full_extent_t, int, std::pair> - , std::tuple, Kokkos::dextents, args_t<10,20,30,40>, Kokkos::dextents, std::pair, Kokkos::full_extent_t, int, std::pair> - , std::tuple, Kokkos::dextents, args_t<10,20,30,40,50>, Kokkos::dextents, int, std::pair, Kokkos::full_extent_t, int, std::pair> + , std::tuple, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, Kokkos::full_extent_t, std::pair> + , std::tuple, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, Kokkos::full_extent_t, int, std::pair> + , std::tuple, Kokkos::dextents, args_t<10,20,30,40>, Kokkos::dextents, std::pair, Kokkos::full_extent_t, int, std::pair> + , std::tuple, Kokkos::dextents, args_t<10,20,30,40,50>, Kokkos::dextents, int, std::pair, Kokkos::full_extent_t, int, std::pair> // layout_right to layout_stride , std::tuple, args_t<10>, Kokkos::dextents, Kokkos::strided_slice> , std::tuple, args_t<10>, Kokkos::extents, Kokkos::strided_slice,std::integral_constant>> @@ -172,6 +178,34 @@ using submdspan_test_types = , std::tuple, args_t<10,20>, Kokkos::dextents, Kokkos::strided_slice, Kokkos::strided_slice> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::extents, Kokkos::full_extent_t, int, std::pair, int, int, Kokkos::full_extent_t> , std::tuple, args_t<6,4,5,6,7,8>, Kokkos::extents, int, Kokkos::full_extent_t, std::pair, int, Kokkos::full_extent_t, int> + // layout_left_padded to layout_left + , std::tuple, Kokkos::layout_left, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, int, int> + , std::tuple, Kokkos::layout_left, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, std::pair, int> + , std::tuple, Kokkos::layout_left, Kokkos::extents, args_t<10,20>, Kokkos::extents, Kokkos::full_extent_t, int> + , std::tuple, Kokkos::layout_left, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, std::pair, int, int> + , std::tuple, Kokkos::layout_left, Kokkos::extents, args_t<10,20,30>, Kokkos::extents, Kokkos::full_extent_t, int, int> + // layout_left_padded to layout_left_padded + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<>, Kokkos::dextents> + , std::tuple, layout_left_padded<4>, Kokkos::dextents, args_t<>, Kokkos::dextents> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, Kokkos::full_extent_t, Kokkos::full_extent_t> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, std::pair, Kokkos::full_extent_t> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, Kokkos::full_extent_t, std::pair> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, std::pair, std::pair> + , std::tuple, layout_left_padded<12>, Kokkos::extents, args_t<10,20>, Kokkos::extents, std::pair, Kokkos::full_extent_t> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, Kokkos::full_extent_t, int, Kokkos::full_extent_t> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, std::pair, int, Kokkos::full_extent_t> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, Kokkos::full_extent_t, int, std::pair> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, std::pair, int, std::pair> + , std::tuple, layout_left_padded<240>, Kokkos::extents, args_t<10,20,30>, Kokkos::extents, std::pair, int, Kokkos::full_extent_t> + , std::tuple, layout_left_padded, Kokkos::dextents, args_t<10,20,30,40>, Kokkos::dextents, Kokkos::full_extent_t, int, Kokkos::full_extent_t, Kokkos::full_extent_t> + // layout_left_padded to layout_stride + , std::tuple, Kokkos::layout_stride, Kokkos::dextents, args_t<10>, Kokkos::dextents, Kokkos::strided_slice> + , std::tuple, Kokkos::layout_stride, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, Kokkos::strided_slice, int> + , std::tuple, Kokkos::layout_stride, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, Kokkos::strided_slice, Kokkos::full_extent_t> + , std::tuple, Kokkos::layout_stride, Kokkos::dextents, args_t<10,20>, Kokkos::dextents, Kokkos::full_extent_t, Kokkos::strided_slice> + , std::tuple, Kokkos::layout_stride, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, Kokkos::full_extent_t, Kokkos::strided_slice, Kokkos::full_extent_t> + , std::tuple, Kokkos::layout_stride, Kokkos::dextents, args_t<10,20,30>, Kokkos::dextents, Kokkos::full_extent_t, int, Kokkos::strided_slice> + , std::tuple, Kokkos::layout_stride, Kokkos::dextents, args_t<10,20,30,40>, Kokkos::dextents, Kokkos::full_extent_t, Kokkos::full_extent_t, int, Kokkos::full_extent_t> // Testing of customization point design , std::tuple, args_t<10>, Kokkos::dextents, Kokkos::full_extent_t> , std::tuple, args_t<10>, Kokkos::dextents, std::pair> diff --git a/tests/test_submdspan_static_slice.cpp b/tests/test_submdspan_static_slice.cpp index e7427e49..67c54807 100644 --- a/tests/test_submdspan_static_slice.cpp +++ b/tests/test_submdspan_static_slice.cpp @@ -202,7 +202,7 @@ TEST(TestMdspan, SubmdspanStaticSlice_Left_i345_FullIndexFull) { { using expected_extents_type = Kokkos::extents; - using expected_layout_type = Kokkos::Experimental::layout_left_padded; + using expected_layout_type = Kokkos::Experimental::layout_left_padded<12>; using expected_output_mdspan_type = Kokkos::mdspan; auto runTest = [&] (auto integralConstant) { @@ -424,7 +424,7 @@ TEST(TestMdspan, SubmdspanStaticSlice_Left_i345_TupleFullTuple) { { using expected_extents_type = Kokkos::extents; - using expected_layout_type = Kokkos::Experimental::layout_left_padded; + using expected_layout_type = Kokkos::Experimental::layout_left_padded<3>; using expected_output_mdspan_type = Kokkos::mdspan; auto runTest = [&] (auto sliceSpec0, auto sliceSpec1) { @@ -438,7 +438,7 @@ TEST(TestMdspan, SubmdspanStaticSlice_Left_i345_TupleFullTuple) { } { using expected_extents_type = Kokkos::extents; - using expected_layout_type = Kokkos::Experimental::layout_left_padded; + using expected_layout_type = Kokkos::Experimental::layout_left_padded<3>; using expected_output_mdspan_type = Kokkos::mdspan; auto runTest = [&] (auto sliceSpec0, auto sliceSpec1) {