Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows tolerances to be callable #958

Open
wants to merge 19 commits into
base: fenicsx
Choose a base branch
from
20 changes: 18 additions & 2 deletions src/festim/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,16 @@ def create_solver(self):
bcs=self.bc_forms,
)
self.solver = NewtonSolver(MPI.COMM_WORLD, problem)
self.solver.atol = self.settings.atol
self.solver.rtol = self.settings.rtol
self.solver.atol = (
self.settings.atol
if not callable(self.settings.rtol)
else self.settings.rtol(float(self.t))
)
self.solver.rtol = (
self.settings.rtol
if not callable(self.settings.rtol)
else self.settings.rtol(float(self.t))
)
self.solver.max_it = self.settings.max_iterations

ksp = self.solver.krylov_solver
Expand Down Expand Up @@ -166,6 +174,14 @@ def iterate(self):
self.progress_bar.update(
min(self.dt.value, abs(self.settings.final_time - self.t.value))
)

# update rtol if it's callable
if callable(self.settings.rtol):
self.solver.rtol = self.settings.rtol(self.t.value)
# update rtol if it's callable
if callable(self.settings.atol):
self.solver.atol = self.settings.atol(self.t.value)

self.t.value += self.dt.value

self.update_time_dependent_values()
Expand Down
8 changes: 4 additions & 4 deletions src/festim/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ class Settings:
"""Settings for a festim simulation.

Args:
atol (float): Absolute tolerance for the solver.
rtol (float): Relative tolerance for the solver.
atol (float or callable): Absolute tolerance for the solver.
rtol (float or callable): Relative tolerance for the solver.
max_iterations (int, optional): Maximum number of iterations for the
solver. Defaults to 30.
transient (bool, optional): Whether the simulation is transient or not.
Expand All @@ -19,8 +19,8 @@ class Settings:
convergence_criterion: resiudal or incremental (for Newton solver)

Attributes:
atol (float): Absolute tolerance for the solver.
rtol (float): Relative tolerance for the solver.
atol (float or callable): Absolute tolerance for the solver.
rtol (float or callable): Relative tolerance for the solver.
max_iterations (int): Maximum number of iterations for the solver.
transient (bool): Whether the simulation is transient or not.
final_time (float): Final time for a transient simulation.
Expand Down
60 changes: 60 additions & 0 deletions test/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,63 @@ def test_stepsize_value_wrong_type():

with pytest.raises(TypeError):
my_settings.stepsize = "coucou"


@pytest.mark.parametrize(
"rtol",
[1e-10, lambda t: 1e-8 if t < 10 else 1e-10],
)
def test_callable_rtol(rtol):
"""Tests callable rtol."""
my_settings = F.Settings(atol=0.1, rtol=rtol)

assert my_settings.rtol == rtol


@pytest.mark.parametrize("atol", [1e10, lambda t: 1e12 if t < 10 else 1e10])
def test_callable_atol(atol):
"""Tests callable atol."""
my_settings = F.Settings(atol=atol, rtol=0.1)

assert my_settings.atol == atol


@pytest.mark.parametrize(
"rtol, atol",
[
(lambda t: 1e-8 if t < 10 else 1e-10, lambda t: 1e12 if t < 10 else 1e10),
],
)
def test_tolerances_value(rtol, atol):
"""Tests that callable tolerances are called & return correct float before passed to fenics"""

# BUILD
test_mesh = F.Mesh1D(vertices=np.array([0.0, 1.0, 2.0, 3.0, 4.0]))
dummy_mat = F.Material(D_0=1, E_D=1, name="dummy_mat")

my_vol = F.VolumeSubdomain1D(id=1, borders=[0, 4], material=dummy_mat)
my_model = F.HydrogenTransportProblem(
mesh=test_mesh,
settings=F.Settings(atol=atol, rtol=rtol, transient=True, final_time=10),
subdomains=[my_vol],
temperature=300,
)
H = F.Species("H")
my_model.species = [H]

my_model.sources = [F.ParticleSource(value=1e20, volume=my_vol, species=H)]
my_model.settings.stepsize = F.Stepsize(0.05, milestones=[0.1, 0.2, 0.5, 1]) # s
my_model.initialise()

my_model.t.value = 0.0
my_model.iterate()
# check at t=0
assert my_model.solver.atol == atol(t=0.0)
assert my_model.solver.rtol == rtol(t=0.0)

my_model.t.value = 20
my_model.iterate()

# check at t=20
assert my_model.solver.atol == atol(t=20.0)
assert my_model.solver.rtol == rtol(t=20.0)
Loading