Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a crash and an unclear error message. #351

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,9 @@ def loop(
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded")
solver = eqx.tree_at(
lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none
)
inner_while_loop = ft.partial(_inner_loop, kind=kind)
outer_while_loop = ft.partial(_outer_loop, kind=kind)
final_state = self._loop(
Expand Down
61 changes: 39 additions & 22 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing
import warnings
from collections.abc import Callable
from typing import Any, cast, get_args, get_origin, Optional, Tuple, TYPE_CHECKING
from typing import Any, get_args, get_origin, Optional, Tuple, TYPE_CHECKING

import equinox as eqx
import equinox.internal as eqxi
Expand Down Expand Up @@ -736,27 +736,44 @@ def _wrap(term):
is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
)

if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
if isinstance(solver, AbstractImplicitSolver):
if solver.root_finder.rtol is use_stepsize_tol:
solver = eqx.tree_at(
lambda s: s.root_finder.rtol,
solver,
stepsize_controller.rtol,
)
solver = cast(AbstractImplicitSolver, solver)
if solver.root_finder.atol is use_stepsize_tol:
solver = eqx.tree_at(
lambda s: s.root_finder.atol,
solver,
stepsize_controller.atol,
)
solver = cast(AbstractImplicitSolver, solver)
if solver.root_finder.norm is use_stepsize_tol:
solver = eqx.tree_at(
lambda s: s.root_finder.norm,
solver,
stepsize_controller.norm,
if isinstance(solver, AbstractImplicitSolver):

def _get_tols(x):
outs = []
for attr in ("rtol", "atol", "norm"):
if getattr(solver.root_finder, attr) is use_stepsize_tol: # pyright: ignore
outs.append(getattr(x, attr))
return tuple(outs)

if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
solver = eqx.tree_at(
lambda s: _get_tols(s.root_finder),
solver,
_get_tols(stepsize_controller),
)
else:
if len(_get_tols(solver.root_finder)) > 0:
raise ValueError(
"A fixed step size controller is being used alongside an implicit "
"solver, but the tolerances for the implicit solver have not been "
"specified. (Being unspecified is the default in Diffrax.)\n"
"The correct fix is almost always to use an adaptive step size "
"controller. For example "
"`diffrax.diffeqsolve(..., "
"stepsize_controller=diffrax.PIDController(rtol=..., atol=...))`. "
"In this case the same tolerances are used for the implicit "
"solver as are used to control the adaptive stepping.\n"
"(Note for advanced users: the tolerances for the implicit "
"solver can also be explicitly set instead. For example "
"`diffrax.diffeqsolve(..., solver=diffrax.Kvaerno5(root_finder="
"diffrax.VeryChord(rtol=..., atol=..., "
"norm=optimistix.max_norm)))`. In this case the norm must also be "
"explicitly specified.)\n"
"Adaptive step size controllers are the preferred solution, as "
"sometimes the implicit solver may fail to converge, and in this "
"case an adaptive step size controller can reject the step and try "
"a smaller one, whilst with a fixed step size controller the "
"overall differential equation solve will simply fail."
)

# Error checking
Expand Down
10 changes: 9 additions & 1 deletion diffrax/_root_finder/_with_tols.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import optimistix as optx


use_stepsize_tol = object()
class _UseStepSizeTol:
def __repr__(self):
return (
"<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` "
"argument>"
)


use_stepsize_tol = _UseStepSizeTol()


def with_stepsize_controller_tols(cls: type[optx.AbstractRootFinder]):
Expand Down
13 changes: 13 additions & 0 deletions test/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,16 @@ def run(y0__args, adjoint):
grads3 = run((y0, args), diffrax.RecursiveCheckpointAdjoint())
assert tree_allclose(grads1, grads2, rtol=1e-3, atol=1e-3)
assert tree_allclose(grads1, grads3, rtol=1e-3, atol=1e-3)


def test_implicit_runge_kutta_direct_adjoint():
diffrax.diffeqsolve(
diffrax.ODETerm(lambda t, y, args: -y),
diffrax.Kvaerno5(),
0,
1,
0.01,
1.0,
adjoint=diffrax.DirectAdjoint(),
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
)
13 changes: 13 additions & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,16 @@ def vector_field(t, y, args):
assert text == "static_made_jump=False static_result=None\n"
finally:
diffrax._integrate._PRINT_STATIC = False


def test_implicit_tol_error():
msg = "the tolerances for the implicit solver have not been specified"
with pytest.raises(ValueError, match=msg):
diffrax.diffeqsolve(
diffrax.ODETerm(lambda t, y, args: -y),
diffrax.Kvaerno5(),
0,
1,
0.01,
1.0,
)
Loading