Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
kaelyndunnell committed Mar 3, 2025
1 parent 0af263b commit 829808f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
12 changes: 10 additions & 2 deletions src/festim/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,16 @@ def create_solver(self):
bcs=self.bc_forms,
)
self.solver = NewtonSolver(MPI.COMM_WORLD, problem)
self.solver.atol = self.settings.atol if not callable(self.settings.rtol) else self.settings.rtol(0.0)
self.solver.rtol = self.settings.rtol if not callable(self.settings.rtol) else self.settings.rtol(0.0)
self.solver.atol = (
self.settings.atol
if not callable(self.settings.rtol)
else self.settings.rtol(0.0)
)
self.solver.rtol = (
self.settings.rtol
if not callable(self.settings.rtol)
else self.settings.rtol(0.0)
)
self.solver.max_it = self.settings.max_iterations

ksp = self.solver.krylov_solver
Expand Down
30 changes: 18 additions & 12 deletions test/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,36 @@ def test_stepsize_value_wrong_type():
with pytest.raises(TypeError):
my_settings.stepsize = "coucou"


@pytest.mark.parametrize(
"rtol", [1e-10, lambda t: 1e-8 if t<10 else 1e-10],
)
def test_callable_rtol(rtol):
"rtol",
[1e-10, lambda t: 1e-8 if t < 10 else 1e-10],
)
def test_callable_rtol(rtol):
"""Tests callable rtol."""
my_settings = F.Settings(atol=0.1, rtol=rtol)

assert my_settings.rtol == rtol

@pytest.mark.parametrize(
"atol", [1e10, lambda t: 1e12 if t<10 else 1e10]
)
def test_callable_atol(atol):

@pytest.mark.parametrize("atol", [1e10, lambda t: 1e12 if t < 10 else 1e10])
def test_callable_atol(atol):
"""Tests callable atol."""
my_settings = F.Settings(atol=atol, rtol=0.1)

assert my_settings.atol == atol


@pytest.mark.parametrize(
"rtol, atol", [(1e10,1e10), (lambda t: 1e-8 if t<10 else 1e-10,lambda t: 1e12 if t<10 else 1e10)]
)
def test_tolerances_solve_before_passed_to_fenics(rtol,atol):
"rtol, atol",
[
(1e10, 1e10),
(lambda t: 1e-8 if t < 10 else 1e-10, lambda t: 1e12 if t < 10 else 1e10),
],
)
def test_tolerances_solve_before_passed_to_fenics(rtol, atol):
"""Tests that the tolerances, if callable, are called & return an integer before passed to fenics"""

# BUILD
test_mesh = F.Mesh1D(vertices=np.array([0.0, 1.0, 2.0, 3.0, 4.0]))
dummy_mat = F.Material(D_0=1, E_D=1, name="dummy_mat")
Expand All @@ -65,4 +71,4 @@ def test_tolerances_solve_before_passed_to_fenics(rtol,atol):
my_model.initialise()

# RUN & TEST
my_model.run()
my_model.run()

0 comments on commit 829808f

Please sign in to comment.