Skip to content

Commit

Permalink
Resolve jacobians from duplicate meshes (#733)
Browse files Browse the repository at this point in the history
* Fix multiple jacobians

* add test that currently fails on main

* ruff

---------

Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com>
  • Loading branch information
jorgensd and mscroggs authored Feb 23, 2025
1 parent c78d7ea commit cbfc0f0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ffcx/codegeneration/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ def x_component(self, mt):

def J_component(self, mt):
"""Jacobian component."""
# FIXME: Add domain number!
return L.Symbol(format_mt_name("J", mt), dtype=L.DataType.REAL)
return L.Symbol(
format_mt_name(f"J{mt.expr.ufl_domain().ufl_id()}", mt), dtype=L.DataType.REAL
)

def domain_dof_access(self, dof, component, gdim, num_scalar_dofs, restriction):
"""Domain DOF access."""
Expand Down
24 changes: 24 additions & 0 deletions test/test_jit_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,30 @@ def test_integral_grouping(compile_args):
assert len(unique_integrals) == 2


def test_derivative_domains(compile_args):
"""Test a form with derivatives on two different domains will generate valid code."""

V_ele = basix.ufl.element("Lagrange", "triangle", 2)
W_ele = basix.ufl.element("Lagrange", "interval", 1)

gdim = 2
V_domain = ufl.Mesh(basix.ufl.element("Lagrange", "triangle", 1, shape=(gdim,)))
W_domain = ufl.Mesh(basix.ufl.element("Lagrange", "interval", 1, shape=(gdim,)))

V = ufl.FunctionSpace(V_domain, V_ele)
W = ufl.FunctionSpace(W_domain, W_ele)

u = ufl.TrialFunction(V)
q = ufl.TestFunction(W)

ds = ufl.Measure("ds", domain=V_domain)

forms = [ufl.inner(u.dx(0), q.dx(0)) * ds]
compiled_forms, module, code = ffcx.codegeneration.jit.compile_forms(
forms, options={"scalar_type": np.float64}, cffi_extra_compile_args=compile_args
)


@pytest.mark.parametrize("dtype", ["float64"])
@pytest.mark.parametrize("permutation", [[0], [1]])
def test_mixed_dim_form(compile_args, dtype, permutation):
Expand Down

0 comments on commit cbfc0f0

Please sign in to comment.