From d7e30e4663f70885ed3ab8520b67e7378d082dd4 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 14 Feb 2025 15:28:34 +0000 Subject: [PATCH] start working on compatibility check --- cpp/basix/finite-element.cpp | 14 ++++++ cpp/basix/finite-element.h | 70 +++++++++++++++------------- python/basix/finite_element.py | 16 +++++++ python/wrapper.cpp | 6 +++ test/test_elements.py | 83 ++++++++++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 32 deletions(-) diff --git a/cpp/basix/finite-element.cpp b/cpp/basix/finite-element.cpp index 411a61fd2..bbe8da29a 100644 --- a/cpp/basix/finite-element.cpp +++ b/cpp/basix/finite-element.cpp @@ -608,6 +608,20 @@ element::make_discontinuous( /// @endcond //----------------------------------------------------------------------------- template +bool element::compatible(const basix::FiniteElement& e0, const basix::FiniteElement& e1) { + // TODO + return false; + if (e0.value_shape() != e1.value_shape()) + return false; + return true; +} +//----------------------------------------------------------------------------- +/// @cond +template bool element::compatible(const basix::FiniteElement& e0, const basix::FiniteElement& e1); +template bool element::compatible(const basix::FiniteElement& e0, const basix::FiniteElement& e1); +/// @endcond +//----------------------------------------------------------------------------- +template FiniteElement basix::create_custom_element( cell::type cell_type, const std::vector& value_shape, impl::mdspan_t wcoeffs, diff --git a/cpp/basix/finite-element.h b/cpp/basix/finite-element.h index 9fe29e9e4..fe3fc8539 100644 --- a/cpp/basix/finite-element.h +++ b/cpp/basix/finite-element.h @@ -96,38 +96,6 @@ to_mdspan(const std::array>, 4>& M, } // namespace impl -namespace element -{ -/// Typedef for mdspan -template -using mdspan_t = impl::mdspan_t; - -/// Create a version of the interpolation points, interpolation -/// matrices and entity transformation that represent a discontinuous -/// version of the element. This discontinuous version will have the -/// same DOFs but they will all be associated with the interior of the -/// reference cell. -/// @param[in] x Interpolation points. Indices are (tdim, entity index, -/// point index, dim) -/// @param[in] M The interpolation matrices. Indices are (tdim, entity -/// index, dof, vs, point_index, derivative) -/// @param[in] tdim The topological dimension of the cell the element is -/// defined on -/// @param[in] value_size The value size of the element -/// @return (xdata, xshape, Mdata, Mshape), where the x and M data are -/// for a discontinuous version of the element (with the same shapes as -/// x and M) -template -std::tuple>, 4>, - std::array>, 4>, - std::array>, 4>, - std::array>, 4>> -make_discontinuous(const std::array>, 4>& x, - const std::array>, 4>& M, - std::size_t tdim, std::size_t value_size); - -} // namespace element - /// @brief A finite element. /// /// The basis of a finite element is stored as a set of coefficients, @@ -1562,6 +1530,44 @@ class FiniteElement std::array _M; }; +namespace element +{ +/// Typedef for mdspan +template +using mdspan_t = impl::mdspan_t; + +/// Create a version of the interpolation points, interpolation +/// matrices and entity transformation that represent a discontinuous +/// version of the element. This discontinuous version will have the +/// same DOFs but they will all be associated with the interior of the +/// reference cell. +/// @param[in] x Interpolation points. Indices are (tdim, entity index, +/// point index, dim) +/// @param[in] M The interpolation matrices. Indices are (tdim, entity +/// index, dof, vs, point_index, derivative) +/// @param[in] tdim The topological dimension of the cell the element is +/// defined on +/// @param[in] value_size The value size of the element +/// @return (xdata, xshape, Mdata, Mshape), where the x and M data are +/// for a discontinuous version of the element (with the same shapes as +/// x and M) +template +std::tuple>, 4>, + std::array>, 4>, + std::array>, 4>, + std::array>, 4>> +make_discontinuous(const std::array>, 4>& x, + const std::array>, 4>& M, + std::size_t tdim, std::size_t value_size); + +/// Check if two elements are compatible +/// @param[in] e0 The first element +/// @param[in] e1 The second element +template +bool compatible(const FiniteElement& e0, const FiniteElement& e1); + +} // namespace element + /// Create a custom finite element /// @param[in] cell_type The cell type /// @param[in] value_shape The value shape of the element diff --git a/python/basix/finite_element.py b/python/basix/finite_element.py index 31c3e0310..714ef7cee 100644 --- a/python/basix/finite_element.py +++ b/python/basix/finite_element.py @@ -23,6 +23,7 @@ from basix._basixcpp import create_tp_element as _create_tp_element from basix._basixcpp import tp_dof_ordering as _tp_dof_ordering from basix._basixcpp import tp_factors as _tp_factors +from basix._basixcpp import compatible as _compatible from basix.cell import CellType from basix import MapType from basix.polynomials import PolysetType @@ -870,3 +871,18 @@ def string_to_family(family: str, cell: str) -> ElementFamily: return families[family] except KeyError: raise ValueError(f"Unknown element family: {family} with cell type {cell}") + + +def compatible(element0: FiniteElement, element1: FiniteElement) -> bool: + """Check if two elements are compatible. + + Two elements are compatible if their DOFs on and shared sub-entities are the same. + + Args: + element0: The first element. + element1: The second element. + + Returns: + True or false. + """ + return _compatible(element0._e, element1._e) diff --git a/python/wrapper.cpp b/python/wrapper.cpp index c6a2cc5f8..8f8d16609 100644 --- a/python/wrapper.cpp +++ b/python/wrapper.cpp @@ -413,6 +413,12 @@ void declare_float(nb::module_& m, std::string type) return as_nbarrayp(polyset::tabulate(celltype, polytype, d, n, _x)); }, "celltype"_a, "polytype"_a, "d"_a, "n"_a, "x"_a.noconvert()); + + m.def( + "compatible", + [](const FiniteElement& e0, const FiniteElement& e1) + { return basix::element::compatible(e0, e1); }, + "e0"_a, "e1"_a); } } // namespace diff --git a/test/test_elements.py b/test/test_elements.py index 72f5b7591..0f067b29f 100644 --- a/test/test_elements.py +++ b/test/test_elements.py @@ -229,3 +229,86 @@ def test_hash(): for i, d0 in enumerate(different_elements): for d1 in different_elements[:i]: assert hash(d0) != hash(d1) + + +@pytest.mark.parametrize("family", [ + basix.ElementFamily.P, + basix.ElementFamily.RT, + basix.ElementFamily.N1E, + basix.ElementFamily.HHJ, +]) +@pytest.mark.parametrize("degree", [1, 2]) +@pytest.mark.parametrize("cell0", [ + basix.CellType.quadrilateral, basix.CellType.triangle, + basix.CellType.tetrahedron, + basix.CellType.hexahedron, +]) +@pytest.mark.parametrize("cell1", [ + basix.CellType.quadrilateral, basix.CellType.triangle, + basix.CellType.tetrahedron, + basix.CellType.hexahedron, +]) +def test_compatible_same_type(family, degree, cell0, cell1): + assert basix.finite_element.compatible( + basix.create_element(family, cell0, degree), + basix.create_element(family, cell1, degree), + ) + + +@pytest.mark.parametrize("family", [ + basix.ElementFamily.P, + basix.ElementFamily.RT, + basix.ElementFamily.N1E, + basix.ElementFamily.HHJ, +]) +@pytest.mark.parametrize("cell0", [basix.CellType.quadrilateral, basix.CellType.triangle]) +@pytest.mark.parametrize("cell1", [basix.CellType.quadrilateral, basix.CellType.triangle]) +@pytest.mark.parametrize("degree0,degree1", [(1, 2), (2, 1)]) +def test_compatible_different_degree(family, cell0, cell1, degree0, degree1): + assert not basix.finite_element.compatible( + basix.create_element(basix.ElementFamily.P, cell0, degree0), + basix.create_element(basix.ElementFamily.P, cell1, degree1), + ) + + +@pytest.mark.parametrize("cell0", [ + basix.CellType.quadrilateral, + basix.CellType.hexahedron, +]) +@pytest.mark.parametrize("cell1", [ + basix.CellType.quadrilateral, basix.CellType.triangle, + basix.CellType.tetrahedron, + basix.CellType.hexahedron, +]) +def test_compatible_lagrange_serendipity(cell0, cell1): + assert not basix.finite_element.compatible( + basix.create_element(basix.ElementFamily.serendipity, cell0, 1), + basix.create_element(basix.ElementFamily.P, cell1, 1), + ) + +@pytest.mark.parametrize("degree", range(1, 5)) +@pytest.mark.parametrize("cell0", [ + basix.CellType.quadrilateral, basix.CellType.triangle, + basix.CellType.tetrahedron, + basix.CellType.hexahedron, +]) +@pytest.mark.parametrize("cell1", [ + basix.CellType.quadrilateral, basix.CellType.triangle, + basix.CellType.tetrahedron, + basix.CellType.hexahedron, +]) +@pytest.mark.parametrize("variant0", [ + basix.LagrangeVariant.equispaced, + basix.LagrangeVariant.gll_isaac, + basix.LagrangeVariant.gll_centroid, +]) +@pytest.mark.parametrize("variant1", [ + basix.LagrangeVariant.equispaced, + basix.LagrangeVariant.gll_isaac, + basix.LagrangeVariant.gll_centroid, +]) +def test_compatible_variants(degree, cell0, cell1, variant0, variant1): + assert basix.finite_element.compatible( + basix.create_element(basix.ElementFamily.P, cell0, degree, lagrange_variant=variant0), + basix.create_element(basix.ElementFamily.P, cell1, degree, lagrange_variant=variant1), + ) == (variant0 == variant1)