Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows tolerances to be callable #958

Open
wants to merge 19 commits into
base: fenicsx
Choose a base branch
from
18 changes: 16 additions & 2 deletions src/festim/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,16 @@ def create_solver(self):
bcs=self.bc_forms,
)
self.solver = NewtonSolver(MPI.COMM_WORLD, problem)
self.solver.atol = self.settings.atol
self.solver.rtol = self.settings.rtol
self.solver.atol = (
self.settings.atol
if not callable(self.settings.rtol)
else self.settings.rtol(0.0)
)
self.solver.rtol = (
self.settings.rtol
if not callable(self.settings.rtol)
else self.settings.rtol(0.0)
)
self.solver.max_it = self.settings.max_iterations

ksp = self.solver.krylov_solver
Expand Down Expand Up @@ -151,6 +159,12 @@ def run(self):
unit_scale=True,
)
while self.t.value < self.settings.final_time:
# update rtol if it's callable
if callable(self.settings.rtol):
self.solver.rtol = self.settings.rtol(self.t.value)
# update rtol if it's callable
if callable(self.settings.atol):
self.solver.atol = self.settings.atol(self.t.value)
self.iterate()
if self.show_progress_bar:
self.progress_bar.refresh() # refresh progress bar to show 100%
Expand Down
51 changes: 51 additions & 0 deletions test/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,54 @@ def test_stepsize_value_wrong_type():

with pytest.raises(TypeError):
my_settings.stepsize = "coucou"


@pytest.mark.parametrize(
"rtol",
[1e-10, lambda t: 1e-8 if t < 10 else 1e-10],
)
def test_callable_rtol(rtol):
"""Tests callable rtol."""
my_settings = F.Settings(atol=0.1, rtol=rtol)

assert my_settings.rtol == rtol


@pytest.mark.parametrize("atol", [1e10, lambda t: 1e12 if t < 10 else 1e10])
def test_callable_atol(atol):
"""Tests callable atol."""
my_settings = F.Settings(atol=atol, rtol=0.1)

assert my_settings.atol == atol


@pytest.mark.parametrize(
"rtol, atol",
[
(1e10, 1e10),
(lambda t: 1e-8 if t < 10 else 1e-10, lambda t: 1e12 if t < 10 else 1e10),
],
)
def test_tolerances_solve_before_passed_to_fenics(rtol, atol):
"""Tests that the tolerances, if callable, are called & return an integer before passed to fenics"""

# BUILD
test_mesh = F.Mesh1D(vertices=np.array([0.0, 1.0, 2.0, 3.0, 4.0]))
dummy_mat = F.Material(D_0=1, E_D=1, name="dummy_mat")

my_vol = F.VolumeSubdomain1D(id=1, borders=[0, 4], material=dummy_mat)
my_model = F.HydrogenTransportProblem(
mesh=test_mesh,
settings=F.Settings(atol=atol, rtol=rtol, transient=True, final_time=10),
subdomains=[my_vol],
temperature=300,
)
H = F.Species("H")
my_model.species = [H]

my_model.sources = [F.ParticleSource(value=1e20, volume=my_vol, species=H)]
my_model.settings.stepsize = F.Stepsize(0.05, milestones=[0.1, 0.2, 0.5, 1]) # s
my_model.initialise()

# RUN & TEST
my_model.run()
Loading