Skip to content

Commit

Permalink
start working on compatibility check
Browse files Browse the repository at this point in the history
  • Loading branch information
mscroggs committed Feb 14, 2025
1 parent 3a9dab1 commit d7e30e4
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 32 deletions.
14 changes: 14 additions & 0 deletions cpp/basix/finite-element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,20 @@ element::make_discontinuous(
/// @endcond
//-----------------------------------------------------------------------------
template <std::floating_point T>
bool element::compatible(const basix::FiniteElement<T>& e0, const basix::FiniteElement<T>& e1) {
// TODO
return false;
if (e0.value_shape() != e1.value_shape())
return false;
return true;
}
//-----------------------------------------------------------------------------
/// @cond
template bool element::compatible(const basix::FiniteElement<float>& e0, const basix::FiniteElement<float>& e1);
template bool element::compatible(const basix::FiniteElement<double>& e0, const basix::FiniteElement<double>& e1);
/// @endcond
//-----------------------------------------------------------------------------
template <std::floating_point T>
FiniteElement<T> basix::create_custom_element(
cell::type cell_type, const std::vector<std::size_t>& value_shape,
impl::mdspan_t<const T, 2> wcoeffs,
Expand Down
70 changes: 38 additions & 32 deletions cpp/basix/finite-element.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,38 +96,6 @@ to_mdspan(const std::array<std::vector<std::vector<T>>, 4>& M,

} // namespace impl

namespace element
{
/// Typedef for mdspan
template <typename T, std::size_t d>
using mdspan_t = impl::mdspan_t<T, d>;

/// Create a version of the interpolation points, interpolation
/// matrices and entity transformation that represent a discontinuous
/// version of the element. This discontinuous version will have the
/// same DOFs but they will all be associated with the interior of the
/// reference cell.
/// @param[in] x Interpolation points. Indices are (tdim, entity index,
/// point index, dim)
/// @param[in] M The interpolation matrices. Indices are (tdim, entity
/// index, dof, vs, point_index, derivative)
/// @param[in] tdim The topological dimension of the cell the element is
/// defined on
/// @param[in] value_size The value size of the element
/// @return (xdata, xshape, Mdata, Mshape), where the x and M data are
/// for a discontinuous version of the element (with the same shapes as
/// x and M)
template <std::floating_point T>
std::tuple<std::array<std::vector<std::vector<T>>, 4>,
std::array<std::vector<std::array<std::size_t, 2>>, 4>,
std::array<std::vector<std::vector<T>>, 4>,
std::array<std::vector<std::array<std::size_t, 4>>, 4>>
make_discontinuous(const std::array<std::vector<mdspan_t<const T, 2>>, 4>& x,
const std::array<std::vector<mdspan_t<const T, 4>>, 4>& M,
std::size_t tdim, std::size_t value_size);

} // namespace element

/// @brief A finite element.
///
/// The basis of a finite element is stored as a set of coefficients,
Expand Down Expand Up @@ -1562,6 +1530,44 @@ class FiniteElement
std::array<array4_t, 4> _M;
};

namespace element
{
/// Typedef for mdspan
template <typename T, std::size_t d>
using mdspan_t = impl::mdspan_t<T, d>;

/// Create a version of the interpolation points, interpolation
/// matrices and entity transformation that represent a discontinuous
/// version of the element. This discontinuous version will have the
/// same DOFs but they will all be associated with the interior of the
/// reference cell.
/// @param[in] x Interpolation points. Indices are (tdim, entity index,
/// point index, dim)
/// @param[in] M The interpolation matrices. Indices are (tdim, entity
/// index, dof, vs, point_index, derivative)
/// @param[in] tdim The topological dimension of the cell the element is
/// defined on
/// @param[in] value_size The value size of the element
/// @return (xdata, xshape, Mdata, Mshape), where the x and M data are
/// for a discontinuous version of the element (with the same shapes as
/// x and M)
template <std::floating_point T>
std::tuple<std::array<std::vector<std::vector<T>>, 4>,
std::array<std::vector<std::array<std::size_t, 2>>, 4>,
std::array<std::vector<std::vector<T>>, 4>,
std::array<std::vector<std::array<std::size_t, 4>>, 4>>
make_discontinuous(const std::array<std::vector<mdspan_t<const T, 2>>, 4>& x,
const std::array<std::vector<mdspan_t<const T, 4>>, 4>& M,
std::size_t tdim, std::size_t value_size);

/// Check if two elements are compatible
/// @param[in] e0 The first element
/// @param[in] e1 The second element
template <std::floating_point T>
bool compatible(const FiniteElement<T>& e0, const FiniteElement<T>& e1);

} // namespace element

/// Create a custom finite element
/// @param[in] cell_type The cell type
/// @param[in] value_shape The value shape of the element
Expand Down
16 changes: 16 additions & 0 deletions python/basix/finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from basix._basixcpp import create_tp_element as _create_tp_element
from basix._basixcpp import tp_dof_ordering as _tp_dof_ordering
from basix._basixcpp import tp_factors as _tp_factors
from basix._basixcpp import compatible as _compatible
from basix.cell import CellType
from basix import MapType
from basix.polynomials import PolysetType
Expand Down Expand Up @@ -870,3 +871,18 @@ def string_to_family(family: str, cell: str) -> ElementFamily:
return families[family]
except KeyError:
raise ValueError(f"Unknown element family: {family} with cell type {cell}")


def compatible(element0: FiniteElement, element1: FiniteElement) -> bool:
"""Check if two elements are compatible.
Two elements are compatible if their DOFs on and shared sub-entities are the same.
Args:
element0: The first element.
element1: The second element.
Returns:
True or false.
"""
return _compatible(element0._e, element1._e)
6 changes: 6 additions & 0 deletions python/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ void declare_float(nb::module_& m, std::string type)
return as_nbarrayp(polyset::tabulate(celltype, polytype, d, n, _x));
},
"celltype"_a, "polytype"_a, "d"_a, "n"_a, "x"_a.noconvert());

m.def(
"compatible",
[](const FiniteElement<T>& e0, const FiniteElement<T>& e1)
{ return basix::element::compatible(e0, e1); },
"e0"_a, "e1"_a);
}

} // namespace
Expand Down
83 changes: 83 additions & 0 deletions test/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,86 @@ def test_hash():
for i, d0 in enumerate(different_elements):
for d1 in different_elements[:i]:
assert hash(d0) != hash(d1)


@pytest.mark.parametrize("family", [
basix.ElementFamily.P,
basix.ElementFamily.RT,
basix.ElementFamily.N1E,
basix.ElementFamily.HHJ,
])
@pytest.mark.parametrize("degree", [1, 2])
@pytest.mark.parametrize("cell0", [
basix.CellType.quadrilateral, basix.CellType.triangle,
basix.CellType.tetrahedron,
basix.CellType.hexahedron,
])
@pytest.mark.parametrize("cell1", [
basix.CellType.quadrilateral, basix.CellType.triangle,
basix.CellType.tetrahedron,
basix.CellType.hexahedron,
])
def test_compatible_same_type(family, degree, cell0, cell1):
assert basix.finite_element.compatible(
basix.create_element(family, cell0, degree),
basix.create_element(family, cell1, degree),
)


@pytest.mark.parametrize("family", [
basix.ElementFamily.P,
basix.ElementFamily.RT,
basix.ElementFamily.N1E,
basix.ElementFamily.HHJ,
])
@pytest.mark.parametrize("cell0", [basix.CellType.quadrilateral, basix.CellType.triangle])
@pytest.mark.parametrize("cell1", [basix.CellType.quadrilateral, basix.CellType.triangle])
@pytest.mark.parametrize("degree0,degree1", [(1, 2), (2, 1)])
def test_compatible_different_degree(family, cell0, cell1, degree0, degree1):
assert not basix.finite_element.compatible(
basix.create_element(basix.ElementFamily.P, cell0, degree0),
basix.create_element(basix.ElementFamily.P, cell1, degree1),
)


@pytest.mark.parametrize("cell0", [
basix.CellType.quadrilateral,
basix.CellType.hexahedron,
])
@pytest.mark.parametrize("cell1", [
basix.CellType.quadrilateral, basix.CellType.triangle,
basix.CellType.tetrahedron,
basix.CellType.hexahedron,
])
def test_compatible_lagrange_serendipity(cell0, cell1):
assert not basix.finite_element.compatible(
basix.create_element(basix.ElementFamily.serendipity, cell0, 1),
basix.create_element(basix.ElementFamily.P, cell1, 1),
)

@pytest.mark.parametrize("degree", range(1, 5))
@pytest.mark.parametrize("cell0", [
basix.CellType.quadrilateral, basix.CellType.triangle,
basix.CellType.tetrahedron,
basix.CellType.hexahedron,
])
@pytest.mark.parametrize("cell1", [
basix.CellType.quadrilateral, basix.CellType.triangle,
basix.CellType.tetrahedron,
basix.CellType.hexahedron,
])
@pytest.mark.parametrize("variant0", [
basix.LagrangeVariant.equispaced,
basix.LagrangeVariant.gll_isaac,
basix.LagrangeVariant.gll_centroid,
])
@pytest.mark.parametrize("variant1", [
basix.LagrangeVariant.equispaced,
basix.LagrangeVariant.gll_isaac,
basix.LagrangeVariant.gll_centroid,
])
def test_compatible_variants(degree, cell0, cell1, variant0, variant1):
assert basix.finite_element.compatible(
basix.create_element(basix.ElementFamily.P, cell0, degree, lagrange_variant=variant0),
basix.create_element(basix.ElementFamily.P, cell1, degree, lagrange_variant=variant1),
) == (variant0 == variant1)

0 comments on commit d7e30e4

Please sign in to comment.