Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to type system #3642

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/demo/biharmonic/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@

using namespace dolfinx;
using T = PetscScalar;
using U = typename dolfinx::scalar_value_type_t<T>;
using U = typename dolfinx::scalar_value_t<T>;

// Inside the `main` function, we begin by defining a mesh of the
// domain. As the unit square is a very standard domain, we can use a
Expand Down
2 changes: 1 addition & 1 deletion cpp/demo/codim_0_assembly/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

using namespace dolfinx;
using T = PetscScalar;
using U = typename dolfinx::scalar_value_type_t<T>;
using U = typename dolfinx::scalar_value_t<T>;

int main(int argc, char* argv[])
{
Expand Down
2 changes: 1 addition & 1 deletion cpp/demo/hyperelasticity/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

using namespace dolfinx;
using T = PetscScalar;
using U = typename dolfinx::scalar_value_type_t<T>;
using U = typename dolfinx::scalar_value_t<T>;

/// Hyperelastic problem class
class HyperElasticProblem
Expand Down
2 changes: 1 addition & 1 deletion cpp/demo/mixed_poisson/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@

using namespace dolfinx;
using T = PetscScalar;
using U = typename dolfinx::scalar_value_type_t<T>;
using U = typename dolfinx::scalar_value_t<T>;

int main(int argc, char* argv[])
{
Expand Down
2 changes: 1 addition & 1 deletion cpp/demo/poisson/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@

using namespace dolfinx;
using T = PetscScalar;
using U = typename dolfinx::scalar_value_type_t<T>;
using U = typename dolfinx::scalar_value_t<T>;

// Then follows the definition of the coefficient functions (for $f$ and
// $g$), which are derived from the {cpp:class}`Expression` class in
Expand Down
2 changes: 1 addition & 1 deletion cpp/demo/poisson_matrix_free/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ void solver(MPI_Comm comm)
int main(int argc, char* argv[])
{
using T = PetscScalar;
using U = typename dolfinx::scalar_value_type_t<T>;
using U = typename dolfinx::scalar_value_t<T>;
init_logging(argc, argv);
MPI_Init(&argc, &argv);
solver<T, U>(MPI_COMM_WORLD);
Expand Down
27 changes: 18 additions & 9 deletions cpp/dolfinx/common/types.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2023 Garth N. Wells
// Copyright (C) 2023-2025 Garth N. Wells and Paul T. Kühner
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
Expand All @@ -10,30 +10,39 @@
#include <concepts>
#include <type_traits>

#include <basix/mdspan.hpp>

namespace dolfinx
{
/// @private This concept is used to constrain the a template type to floating
/// point real or complex types. Note that this concept is different to
/// std::floating_point which does not include std::complex.
template <class T>
concept scalar
= std::is_floating_point_v<T> || std::is_same_v<T, std::complex<double>>
|| std::is_same_v<T, std::complex<float>>;
concept scalar = std::floating_point<T>
|| std::is_same_v<T, std::complex<typename T::value_type>>;

/// @private These structs are used to get the float/value type from a
/// template argument, including support for complex types.
template <scalar T, typename = void>
struct scalar_value_type
struct scalar_value
{
/// @internal
typedef T value_type;
typedef T type;
};
/// @private
template <scalar T>
struct scalar_value_type<T, std::void_t<typename T::value_type>>
struct scalar_value<T, std::void_t<typename T::value_type>>
{
typedef typename T::value_type value_type;
typedef typename T::value_type type;
};
/// @private Convenience typedef
template <scalar T>
using scalar_value_type_t = typename scalar_value_type<T>::value_type;
using scalar_value_t = typename scalar_value<T>::type;

/// Namespace containing the `mdspan` implementation
namespace md = MDSPAN_IMPL_STANDARD_NAMESPACE;

/// @private DofMap span data layout
using DofMapSpan = md::mdspan<const std::int32_t, md::dextents<std::size_t, 2>>;

} // namespace dolfinx
2 changes: 1 addition & 1 deletion cpp/dolfinx/fem/DirichletBC.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ std::array<std::vector<std::int32_t>, 2> locate_dofs_geometrical(
/// space (trial space) and degrees of freedom to which the boundary
/// condition applies.
template <dolfinx::scalar T,
std::floating_point U = dolfinx::scalar_value_type_t<T>>
std::floating_point U = dolfinx::scalar_value_t<T>>
class DirichletBC
{
private:
Expand Down
2 changes: 1 addition & 1 deletion cpp/dolfinx/fem/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Function;
/// @tparam T The scalar type.
/// @tparam U The mesh geometry scalar type.
template <dolfinx::scalar T,
std::floating_point U = dolfinx::scalar_value_type_t<T>>
std::floating_point U = dolfinx::scalar_value_t<T>>
class Expression
{
public:
Expand Down
5 changes: 2 additions & 3 deletions cpp/dolfinx/fem/Form.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ std::vector<std::int32_t> compute_domain(
/// @brief Represents integral data, containing the kernel, and a list
/// of entities to integrate over and the indicies of the coefficient
/// functions (relative to the Form) active for this integral.
template <dolfinx::scalar T, std::floating_point U = scalar_value_type_t<T>>
template <dolfinx::scalar T, std::floating_point U = scalar_value_t<T>>
struct integral_data
{
/// @brief Create a structure to hold integral data.
Expand Down Expand Up @@ -170,8 +170,7 @@ struct integral_data
/// @tparam U Float (real) type used for the finite element and
/// geometry.
/// @tparam Kern Element kernel.
template <dolfinx::scalar T,
std::floating_point U = dolfinx::scalar_value_type_t<T>>
template <dolfinx::scalar T, std::floating_point U = dolfinx::scalar_value_t<T>>
class Form
{
public:
Expand Down
2 changes: 1 addition & 1 deletion cpp/dolfinx/fem/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Expression;
/// @tparam T The function scalar type.
/// @tparam U The mesh geometry scalar type.
template <dolfinx::scalar T,
std::floating_point U = dolfinx::scalar_value_type_t<T>>
std::floating_point U = dolfinx::scalar_value_t<T>>
class Function
{
public:
Expand Down
2 changes: 1 addition & 1 deletion cpp/dolfinx/fem/assemble_expression_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void tabulate_expression(
const std::int32_t,
MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
x_dofmap,
std::span<const scalar_value_type_t<T>> x,
std::span<const scalar_value_t<T>> x,
MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
coeffs,
Expand Down
37 changes: 16 additions & 21 deletions cpp/dolfinx/fem/assemble_matrix_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "traits.h"
#include "utils.h"
#include <algorithm>
#include <dolfinx/common/types.h>
#include <dolfinx/la/utils.h>
#include <dolfinx/mesh/Geometry.h>
#include <dolfinx/mesh/Mesh.h>
Expand All @@ -24,11 +25,6 @@

namespace dolfinx::fem::impl
{
/// @brief Typedef
using mdspan2_t = MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
const std::int32_t,
MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>;

/// @brief Execute kernel over cells and accumulate result in matrix.
/// @tparam T Matrix/form scalar type.
/// @param mat_set Function that accumulates computed entries into a
Expand Down Expand Up @@ -59,12 +55,11 @@ using mdspan2_t = MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
/// mesh
template <dolfinx::scalar T>
void assemble_cells(
la::MatSet<T> auto mat_set, mdspan2_t x_dofmap,
std::span<const scalar_value_type_t<T>> x,
std::span<const std::int32_t> cells,
std::tuple<mdspan2_t, int, std::span<const std::int32_t>> dofmap0,
la::MatSet<T> auto mat_set, DofMapSpan x_dofmap,
std::span<const scalar_value_t<T>> x, std::span<const std::int32_t> cells,
std::tuple<DofMapSpan, int, std::span<const std::int32_t>> dofmap0,
fem::DofTransformKernel<T> auto P0,
std::tuple<mdspan2_t, int, std::span<const std::int32_t>> dofmap1,
std::tuple<DofMapSpan, int, std::span<const std::int32_t>> dofmap1,
fem::DofTransformKernel<T> auto P1T, std::span<const std::int8_t> bc0,
std::span<const std::int8_t> bc1, FEkernel<T> auto kernel,
std::span<const T> coeffs, int cstride, std::span<const T> constants,
Expand All @@ -84,7 +79,7 @@ void assemble_cells(
const int ndim1 = bs1 * num_dofs1;
std::vector<T> Ae(ndim0 * ndim1);
std::span<T> _Ae(Ae);
std::vector<scalar_value_type_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));
std::vector<scalar_value_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));

// Iterate over active cells
assert(cells0.size() == cells.size());
Expand Down Expand Up @@ -193,12 +188,12 @@ void assemble_cells(
/// permutations are not required.
template <dolfinx::scalar T>
void assemble_exterior_facets(
la::MatSet<T> auto mat_set, mdspan2_t x_dofmap,
std::span<const scalar_value_type_t<T>> x, int num_facets_per_cell,
la::MatSet<T> auto mat_set, DofMapSpan x_dofmap,
std::span<const scalar_value_t<T>> x, int num_facets_per_cell,
std::span<const std::int32_t> facets,
std::tuple<mdspan2_t, int, std::span<const std::int32_t>> dofmap0,
std::tuple<DofMapSpan, int, std::span<const std::int32_t>> dofmap0,
fem::DofTransformKernel<T> auto P0,
std::tuple<mdspan2_t, int, std::span<const std::int32_t>> dofmap1,
std::tuple<DofMapSpan, int, std::span<const std::int32_t>> dofmap1,
fem::DofTransformKernel<T> auto P1T, std::span<const std::int8_t> bc0,
std::span<const std::int8_t> bc1, FEkernel<T> auto kernel,
std::span<const T> coeffs, int cstride, std::span<const T> constants,
Expand All @@ -213,7 +208,7 @@ void assemble_exterior_facets(
const auto [dmap1, bs1, facets1] = dofmap1;

// Data structures used in assembly
std::vector<scalar_value_type_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));
std::vector<scalar_value_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));
const int num_dofs0 = dmap0.extent(1);
const int num_dofs1 = dmap1.extent(1);
const int ndim0 = bs0 * num_dofs0;
Expand Down Expand Up @@ -332,8 +327,8 @@ void assemble_exterior_facets(
/// permutations are not required.
template <dolfinx::scalar T>
void assemble_interior_facets(
la::MatSet<T> auto mat_set, mdspan2_t x_dofmap,
std::span<const scalar_value_type_t<T>> x, int num_facets_per_cell,
la::MatSet<T> auto mat_set, DofMapSpan x_dofmap,
std::span<const scalar_value_t<T>> x, int num_facets_per_cell,
std::span<const std::int32_t> facets,
std::tuple<const DofMap&, int, std::span<const std::int32_t>> dofmap0,
fem::DofTransformKernel<T> auto P0,
Expand All @@ -352,7 +347,7 @@ void assemble_interior_facets(
const auto [dmap1, bs1, facets1] = dofmap1;

// Data structures used in assembly
using X = scalar_value_type_t<T>;
using X = scalar_value_t<T>;
std::vector<X> coordinate_dofs(2 * x_dofmap.extent(1) * 3);
std::span<X> cdofs0(coordinate_dofs.data(), x_dofmap.extent(1) * 3);
std::span<X> cdofs1(coordinate_dofs.data() + x_dofmap.extent(1) * 3,
Expand Down Expand Up @@ -493,7 +488,7 @@ void assemble_interior_facets(
template <dolfinx::scalar T, std::floating_point U>
void assemble_matrix(
la::MatSet<T> auto mat_set, const Form<T, U>& a,
std::span<const scalar_value_type_t<T>> x, std::span<const T> constants,
std::span<const scalar_value_t<T>> x, std::span<const T> constants,
const std::map<std::pair<IntegralType, int>,
std::pair<std::span<const T>, int>>& coefficients,
std::span<const std::int8_t> bc0, std::span<const std::int8_t> bc1)
Expand All @@ -520,7 +515,7 @@ void assemble_matrix(
for (int cell_type_idx = 0; cell_type_idx < num_cell_types; ++cell_type_idx)
{
// Geometry dofmap and data
mdspan2_t x_dofmap = mesh->geometry().dofmap(cell_type_idx);
DofMapSpan x_dofmap = mesh->geometry().dofmap(cell_type_idx);

// Get dofmap data
std::shared_ptr<const fem::DofMap> dofmap0
Expand Down
21 changes: 11 additions & 10 deletions cpp/dolfinx/fem/assemble_scalar_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "utils.h"
#include <algorithm>
#include <dolfinx/common/IndexMap.h>
#include <dolfinx/common/types.h>
#include <dolfinx/mesh/Geometry.h>
#include <dolfinx/mesh/Mesh.h>
#include <dolfinx/mesh/Topology.h>
Expand All @@ -22,7 +23,7 @@ namespace dolfinx::fem::impl
{
/// Assemble functional over cells
template <dolfinx::scalar T>
T assemble_cells(mdspan2_t x_dofmap, std::span<const scalar_value_type_t<T>> x,
T assemble_cells(DofMapSpan x_dofmap, std::span<const scalar_value_t<T>> x,
std::span<const std::int32_t> cells, FEkernel<T> auto fn,
std::span<const T> constants, std::span<const T> coeffs,
int cstride)
Expand All @@ -32,7 +33,7 @@ T assemble_cells(mdspan2_t x_dofmap, std::span<const scalar_value_type_t<T>> x,
return value;

// Create data structures used in assembly
std::vector<scalar_value_type_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));
std::vector<scalar_value_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));

// Iterate over all cells
for (std::size_t index = 0; index < cells.size(); ++index)
Expand All @@ -58,8 +59,8 @@ T assemble_cells(mdspan2_t x_dofmap, std::span<const scalar_value_type_t<T>> x,

/// Execute kernel over exterior facets and accumulate result
template <dolfinx::scalar T>
T assemble_exterior_facets(mdspan2_t x_dofmap,
std::span<const scalar_value_type_t<T>> x,
T assemble_exterior_facets(DofMapSpan x_dofmap,
std::span<const scalar_value_t<T>> x,
int num_facets_per_cell,
std::span<const std::int32_t> facets,
FEkernel<T> auto fn, std::span<const T> constants,
Expand All @@ -71,7 +72,7 @@ T assemble_exterior_facets(mdspan2_t x_dofmap,
return value;

// Create data structures used in assembly
std::vector<scalar_value_type_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));
std::vector<scalar_value_t<T>> coordinate_dofs(3 * x_dofmap.extent(1));

// Iterate over all facets
assert(facets.size() % 2 == 0);
Expand Down Expand Up @@ -102,8 +103,8 @@ T assemble_exterior_facets(mdspan2_t x_dofmap,

/// Assemble functional over interior facets
template <dolfinx::scalar T>
T assemble_interior_facets(mdspan2_t x_dofmap,
std::span<const scalar_value_type_t<T>> x,
T assemble_interior_facets(DofMapSpan x_dofmap,
std::span<const scalar_value_t<T>> x,
int num_facets_per_cell,
std::span<const std::int32_t> facets,
FEkernel<T> auto fn, std::span<const T> constants,
Expand All @@ -116,7 +117,7 @@ T assemble_interior_facets(mdspan2_t x_dofmap,
return value;

// Create data structures used in assembly
using X = scalar_value_type_t<T>;
using X = scalar_value_t<T>;
std::vector<X> coordinate_dofs(2 * x_dofmap.extent(1) * 3);
std::span<X> cdofs0(coordinate_dofs.data(), x_dofmap.extent(1) * 3);
std::span<X> cdofs1(coordinate_dofs.data() + x_dofmap.extent(1) * 3,
Expand Down Expand Up @@ -165,8 +166,8 @@ T assemble_interior_facets(mdspan2_t x_dofmap,
/// Assemble functional into an scalar with provided mesh geometry.
template <dolfinx::scalar T, std::floating_point U>
T assemble_scalar(
const fem::Form<T, U>& M, mdspan2_t x_dofmap,
std::span<const scalar_value_type_t<T>> x, std::span<const T> constants,
const fem::Form<T, U>& M, DofMapSpan x_dofmap,
std::span<const scalar_value_t<T>> x, std::span<const T> constants,
const std::map<std::pair<IntegralType, int>,
std::pair<std::span<const T>, int>>& coefficients)
{
Expand Down
Loading
Loading