Skip to content

Commit

Permalink
Remove the variadic templates for batch dimensions
Browse files Browse the repository at this point in the history
- The use of variadic templates for the batch dimensions caused issue
  when compiling with nvcc
- Replaced the variadic template on the dimensions by a single template
  argument for the corresponding DiscreteDomain
  • Loading branch information
tretre91 committed Feb 28, 2025
1 parent 18a10a5 commit 0e71b3e
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 353 deletions.
143 changes: 72 additions & 71 deletions include/ddc/kernels/splines/spline_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <optional>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>

#include <ddc/ddc.hpp>
Expand Down Expand Up @@ -88,64 +89,70 @@ class SplineBuilder
/// @brief The type of the domain for the 1D interpolation mesh used by this class.
using interpolation_domain_type = ddc::DiscreteDomain<interpolation_discrete_dimension_type>;

/// @brief The type of the whole domain representing interpolation points.
template <typename... DDimX>
using batched_interpolation_domain_type = ddc::DiscreteDomain<DDimX...>;
/**
* @brief The type of the whole domain representing interpolation points.
*
* @tparam The batched discrete domain on which the interpolation points are defined.
*/
template <class DDom, class = std::enable_if_t<ddc::is_discrete_domain_v<DDom>>>
using batched_interpolation_domain_type = DDom;

/**
* @brief The type of the batch domain (obtained by removing the dimension of interest
* from the whole domain).
*
* @tparam The batched discrete domain on which the interpolation points are defined.
*
* Example: For batched_interpolation_domain_type = DiscreteDomain<X,Y,Z> and a dimension of interest Y,
* this is DiscreteDomain<X,Z>
*/
template <typename... DDimX>
using batch_domain_type = ddc::remove_dims_of_t<
batched_interpolation_domain_type<DDimX...>,
interpolation_discrete_dimension_type>;
template <class DDom, class = std::enable_if_t<ddc::is_discrete_domain_v<DDom>>>
using batch_domain_type = ddc::remove_dims_of_t<DDom, interpolation_discrete_dimension_type>;

/**
* @brief The type of the whole spline domain (cartesian product of 1D spline domain
* and batch domain) preserving the underlying memory layout (order of dimensions).
*
* @tparam The batched discrete domain on which the interpolation points are defined.
*
* Example: For batched_interpolation_domain_type = DiscreteDomain<X,Y,Z> and a dimension of interest Y
* (associated to a B-splines tag BSplinesY), this is DiscreteDomain<X,BSplinesY,Z>.
*/
template <typename... DDimX>
using batched_spline_domain_type = ddc::replace_dim_of_t<
batched_interpolation_domain_type<DDimX...>,
interpolation_discrete_dimension_type,
bsplines_type>;
template <class DDom, class = std::enable_if_t<ddc::is_discrete_domain_v<DDom>>>
using batched_spline_domain_type
= ddc::replace_dim_of_t<DDom, interpolation_discrete_dimension_type, bsplines_type>;

private:
/**
* @brief The type of the whole spline domain (cartesian product of the 1D spline domain
* and the batch domain) with 1D spline dimension being the leading dimension.
*
* @tparam The batched discrete domain on which the interpolation points are defined.
*
* Example: For batched_interpolation_domain_type = DiscreteDomain<X,Y,Z> and a dimension of interest Y
* (associated to a B-splines tag BSplinesY), this is DiscreteDomain<BSplinesY,X,Z>.
*/
template <typename... DDimX>
template <class DDom, class = std::enable_if_t<ddc::is_discrete_domain_v<DDom>>>
using batched_spline_tr_domain_type =
typename ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_merge_t<
ddc::detail::TypeSeq<bsplines_type>,
ddc::type_seq_remove_t<
ddc::detail::TypeSeq<DDimX...>,
ddc::to_type_seq_t<DDom>,
ddc::detail::TypeSeq<interpolation_discrete_dimension_type>>>>;

public:
/**
* @brief The type of the whole Deriv domain (cartesian product of 1D Deriv domain
* and batch domain) preserving the underlying memory layout (order of dimensions).
*
* @tparam The batched discrete domain on which the interpolation points are defined.
*
* Example: For batched_interpolation_domain_type = DiscreteDomain<X,Y,Z> and a dimension of interest Y,
* this is DiscreteDomain<X,Deriv<Y>,Z>
*/
template <typename... DDimX>
using batched_derivs_domain_type = ddc::replace_dim_of_t<
batched_interpolation_domain_type<DDimX...>,
interpolation_discrete_dimension_type,
deriv_type>;
template <class DDom, class = std::enable_if_t<ddc::is_discrete_domain_v<DDom>>>
using batched_derivs_domain_type
= ddc::replace_dim_of_t<DDom, interpolation_discrete_dimension_type, deriv_type>;

/// @brief Indicates if the degree of the splines is odd or even.
static constexpr bool s_odd = BSplines::degree() % 2;
Expand Down Expand Up @@ -243,9 +250,9 @@ class SplineBuilder
*
* @see MatrixSparse
*/
template <class... DDimX>
template <class DDom, class = std::enable_if_t<ddc::is_discrete_domain_v<DDom>>>
explicit SplineBuilder(
batched_interpolation_domain_type<DDimX...> const& batched_interpolation_domain,
DDom const& batched_interpolation_domain,
std::optional<std::size_t> cols_per_chunk = std::nullopt,
std::optional<unsigned int> preconditioner_max_block_size = std::nullopt)
: SplineBuilder(
Expand Down Expand Up @@ -296,12 +303,12 @@ class SplineBuilder
* Values of the function must be provided on this domain in order
* to build a spline representation of the function (cartesian product of 1D interpolation_domain and batch_domain).
*
* @param batched_interpolation_domain The whole domain on which the interpolation points are defined.
*
* @return The domain for the interpolation mesh.
*/
template <class... DDimX>
batched_interpolation_domain_type<DDimX...> batched_interpolation_domain(
batched_interpolation_domain_type<DDimX...> const& batched_interpolation_domain)
const noexcept
template <class DDom, class = std::enable_if_t<ddc::is_discrete_domain_v<DDom>>>
DDom batched_interpolation_domain(DDom const& batched_interpolation_domain) const noexcept
{
return batched_interpolation_domain;
}
Expand All @@ -311,11 +318,12 @@ class SplineBuilder
*
* Obtained by removing the dimension of interest from the whole interpolation domain.
*
* @param batched_interpolation_domain The whole domain on which the interpolation points are defined.
*
* @return The batch domain.
*/
template <class... DDimX>
batch_domain_type<DDimX...> batch_domain(batched_interpolation_domain_type<DDimX...> const&
batched_interpolation_domain) const noexcept
template <class DDom>
batch_domain_type<DDom> batch_domain(DDom const& batched_interpolation_domain) const noexcept
{
return ddc::remove_dims_of(batched_interpolation_domain, interpolation_domain());
}
Expand All @@ -337,12 +345,13 @@ class SplineBuilder
*
* Spline approximations (spline-transformed functions) are computed on this domain.
*
* @param batched_interpolation_domain The whole domain on which the interpolation points are defined.
*
* @return The domain for the spline coefficients.
*/
template <class... DDimX>
batched_spline_domain_type<DDimX...> batched_spline_domain(
batched_interpolation_domain_type<DDimX...> const& batched_interpolation_domain)
const noexcept
template <class DDom>
batched_spline_domain_type<DDom> batched_spline_domain(
DDom const& batched_interpolation_domain) const noexcept
{
return ddc::replace_dim_of<
interpolation_discrete_dimension_type,
Expand All @@ -355,14 +364,15 @@ class SplineBuilder
*
* This is used internally due to solver limitation and because it may be beneficial to computation performance. For LAPACK backend and non-periodic boundary condition, we are using SplinesLinearSolver3x3Blocks which requires upper_block_size additional rows for internal operations.
*
* @param batched_interpolation_domain The whole domain on which the interpolation points are defined.
*
* @return The (transposed) domain for the spline coefficients.
*/
template <class... DDimX>
batched_spline_tr_domain_type<DDimX...> batched_spline_tr_domain(
batched_interpolation_domain_type<DDimX...> const& batched_interpolation_domain)
const noexcept
template <class DDom>
batched_spline_tr_domain_type<DDom> batched_spline_tr_domain(
DDom const& batched_interpolation_domain) const noexcept
{
return batched_spline_tr_domain_type<DDimX...>(
return batched_spline_tr_domain_type<DDom>(
ddc::replace_dim_of<bsplines_type, bsplines_type>(
batched_spline_domain(batched_interpolation_domain),
ddc::DiscreteDomain<bsplines_type>(
Expand All @@ -377,12 +387,13 @@ class SplineBuilder
*
* This is only used with BoundCond::HERMITE boundary conditions.
*
* @param batched_interpolation_domain The whole domain on which the interpolation points are defined.
*
* @return The domain for the Derivs values.
*/
template <class... DDimX>
batched_derivs_domain_type<DDimX...> batched_derivs_xmin_domain(
batched_interpolation_domain_type<DDimX...> const& batched_interpolation_domain)
const noexcept
template <class DDom>
batched_derivs_domain_type<DDom> batched_derivs_xmin_domain(
DDom const& batched_interpolation_domain) const noexcept
{
return ddc::replace_dim_of<interpolation_discrete_dimension_type, deriv_type>(
batched_interpolation_domain,
Expand All @@ -396,12 +407,13 @@ class SplineBuilder
*
* This is only used with BoundCond::HERMITE boundary conditions.
*
* @param batched_interpolation_domain The whole domain on which the interpolation points are defined.
*
* @return The domain for the Derivs values.
*/
template <class... DDimX>
batched_derivs_domain_type<DDimX...> batched_derivs_xmax_domain(
batched_interpolation_domain_type<DDimX...> const& batched_interpolation_domain)
const noexcept
template <class DDom>
batched_derivs_domain_type<DDom> batched_derivs_xmax_domain(
DDom const& batched_interpolation_domain) const noexcept
{
return ddc::replace_dim_of<interpolation_discrete_dimension_type, deriv_type>(
batched_interpolation_domain,
Expand Down Expand Up @@ -429,24 +441,19 @@ class SplineBuilder
* @param[in] derivs_xmax The values of the derivatives at the upper boundary
* (used only with BoundCond::HERMITE upper boundary condition).
*/
template <class Layout, class... DDimX>
template <class Layout, class DDom>
void operator()(
ddc::ChunkSpan<double, batched_spline_domain_type<DDimX...>, Layout, memory_space>
spline,
ddc::ChunkSpan<
double const,
batched_interpolation_domain_type<DDimX...>,
Layout,
memory_space> vals,
ddc::ChunkSpan<double, batched_spline_domain_type<DDom>, Layout, memory_space> spline,
ddc::ChunkSpan<double const, DDom, Layout, memory_space> vals,
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type<DDimX...>,
batched_derivs_domain_type<DDom>,
Layout,
memory_space>> derivs_xmin
= std::nullopt,
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type<DDimX...>,
batched_derivs_domain_type<DDom>,
Layout,
memory_space>> derivs_xmax
= std::nullopt) const;
Expand Down Expand Up @@ -755,23 +762,19 @@ template <
ddc::BoundCond BcLower,
ddc::BoundCond BcUpper,
SplineSolver Solver>
template <class Layout, class... DDimX>
template <class Layout, class DDom>
void SplineBuilder<ExecSpace, MemorySpace, BSplines, InterpolationDDim, BcLower, BcUpper, Solver>::
operator()(
ddc::ChunkSpan<double, batched_spline_domain_type<DDimX...>, Layout, memory_space> spline,
ddc::ChunkSpan<
double const,
batched_interpolation_domain_type<DDimX...>,
Layout,
memory_space> vals,
ddc::ChunkSpan<double, batched_spline_domain_type<DDom>, Layout, memory_space> spline,
ddc::ChunkSpan<double const, DDom, Layout, memory_space> vals,
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type<DDimX...>,
batched_derivs_domain_type<DDom>,
Layout,
memory_space>> const derivs_xmin,
std::optional<ddc::ChunkSpan<
double const,
batched_derivs_domain_type<DDimX...>,
batched_derivs_domain_type<DDom>,
Layout,
memory_space>> const derivs_xmax) const
{
Expand All @@ -793,8 +796,6 @@ operator()(
assert(ddc::DiscreteElement<deriv_type>(derivs_xmax->domain().front()).uid() == 1);
}

using batch_domain_type = batch_domain_type<DDimX...>;

// Hermite boundary conditions at xmin, if any
// NOTE: For consistency with the linear system, the i-th derivative
// provided by the user must be multiplied by dx^i
Expand All @@ -806,7 +807,7 @@ operator()(
"ddc_splines_hermite_compute_lower_coefficients",
exec_space(),
batch_domain(batched_interpolation_domain),
KOKKOS_LAMBDA(typename batch_domain_type::discrete_element_type j) {
KOKKOS_LAMBDA(typename batch_domain_type<DDom>::discrete_element_type j) {
for (int i = s_nbc_xmin; i > 0; --i) {
spline(ddc::DiscreteElement<bsplines_type>(s_nbc_xmin - i), j)
= derivs_xmin_values(ddc::DiscreteElement<deriv_type>(i), j)
Expand Down Expand Up @@ -849,7 +850,7 @@ operator()(
"ddc_splines_hermite_compute_upper_coefficients",
exec_space(),
batch_domain(batched_interpolation_domain),
KOKKOS_LAMBDA(typename batch_domain_type::discrete_element_type j) {
KOKKOS_LAMBDA(typename batch_domain_type<DDom>::discrete_element_type j) {
for (int i = 0; i < s_nbc_xmax; ++i) {
spline(ddc::DiscreteElement<bsplines_type>(nbasis_proxy - s_nbc_xmax - i),
j)
Expand All @@ -869,7 +870,7 @@ operator()(
"ddc_splines_transpose_rhs",
exec_space(),
batch_domain(batched_interpolation_domain),
KOKKOS_LAMBDA(typename batch_domain_type::discrete_element_type const j) {
KOKKOS_LAMBDA(typename batch_domain_type<DDom>::discrete_element_type const j) {
for (std::size_t i = 0; i < nbasis_proxy; ++i) {
spline_tr(ddc::DiscreteElement<bsplines_type>(i), j)
= spline(ddc::DiscreteElement<bsplines_type>(i + offset_proxy), j);
Expand All @@ -887,7 +888,7 @@ operator()(
"ddc_splines_transpose_back_rhs",
exec_space(),
batch_domain(batched_interpolation_domain),
KOKKOS_LAMBDA(typename batch_domain_type::discrete_element_type const j) {
KOKKOS_LAMBDA(typename batch_domain_type<DDom>::discrete_element_type const j) {
for (std::size_t i = 0; i < nbasis_proxy; ++i) {
spline(ddc::DiscreteElement<bsplines_type>(i + offset_proxy), j)
= spline_tr(ddc::DiscreteElement<bsplines_type>(i), j);
Expand All @@ -900,7 +901,7 @@ operator()(
"ddc_splines_periodic_rows_duplicate_rhs",
exec_space(),
batch_domain(batched_interpolation_domain),
KOKKOS_LAMBDA(typename batch_domain_type::discrete_element_type const j) {
KOKKOS_LAMBDA(typename batch_domain_type<DDom>::discrete_element_type const j) {
if (offset_proxy != 0) {
for (int i = 0; i < offset_proxy; ++i) {
spline(ddc::DiscreteElement<bsplines_type>(i), j) = spline(
Expand Down
Loading

0 comments on commit 0e71b3e

Please sign in to comment.