diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 94fe6092..3d85ec69 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -368,7 +368,9 @@ def loop( # Support forward-mode autodiff. # TODO: remove this hack once we can JVP through custom_vjps. if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None: - solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded") + solver = eqx.tree_at( + lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none + ) inner_while_loop = ft.partial(_inner_loop, kind=kind) outer_while_loop = ft.partial(_outer_loop, kind=kind) final_state = self._loop( diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 74d52102..d33452bb 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -2,7 +2,7 @@ import typing import warnings from collections.abc import Callable -from typing import Any, cast, get_args, get_origin, Optional, Tuple, TYPE_CHECKING +from typing import Any, get_args, get_origin, Optional, Tuple, TYPE_CHECKING import equinox as eqx import equinox.internal as eqxi @@ -736,27 +736,44 @@ def _wrap(term): is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm), ) - if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): - if isinstance(solver, AbstractImplicitSolver): - if solver.root_finder.rtol is use_stepsize_tol: - solver = eqx.tree_at( - lambda s: s.root_finder.rtol, - solver, - stepsize_controller.rtol, - ) - solver = cast(AbstractImplicitSolver, solver) - if solver.root_finder.atol is use_stepsize_tol: - solver = eqx.tree_at( - lambda s: s.root_finder.atol, - solver, - stepsize_controller.atol, - ) - solver = cast(AbstractImplicitSolver, solver) - if solver.root_finder.norm is use_stepsize_tol: - solver = eqx.tree_at( - lambda s: s.root_finder.norm, - solver, - stepsize_controller.norm, + if isinstance(solver, AbstractImplicitSolver): + + def _get_tols(x): + outs = [] + for attr in ("rtol", "atol", "norm"): + if 
getattr(solver.root_finder, attr) is use_stepsize_tol: # pyright: ignore + outs.append(getattr(x, attr)) + return tuple(outs) + + if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): + solver = eqx.tree_at( + lambda s: _get_tols(s.root_finder), + solver, + _get_tols(stepsize_controller), + ) + else: + if len(_get_tols(solver.root_finder)) > 0: + raise ValueError( + "A fixed step size controller is being used alongside an implicit " + "solver, but the tolerances for the implicit solver have not been " + "specified. (Being unspecified is the default in Diffrax.)\n" + "The correct fix is almost always to use an adaptive step size " + "controller. For example " + "`diffrax.diffeqsolve(..., " + "stepsize_controller=diffrax.PIDController(rtol=..., atol=...))`. " + "In this case the same tolerances are used for the implicit " + "solver as are used to control the adaptive stepping.\n" + "(Note for advanced users: the tolerances for the implicit " + "solver can also be explicitly set instead. For example " + "`diffrax.diffeqsolve(..., solver=diffrax.Kvaerno5(root_finder=" + "diffrax.VeryChord(rtol=..., atol=..., " + "norm=optimistix.max_norm)))`. In this case the norm must also be " + "explicitly specified.)\n" + "Adaptive step size controllers are the preferred solution, as " + "sometimes the implicit solver may fail to converge, and in this " + "case an adaptive step size controller can reject the step and try " + "a smaller one, whilst with a fixed step size controller the " + "overall differential equation solve will simply fail." 
) # Error checking diff --git a/diffrax/_root_finder/_with_tols.py b/diffrax/_root_finder/_with_tols.py index 49fcfd22..52779299 100644 --- a/diffrax/_root_finder/_with_tols.py +++ b/diffrax/_root_finder/_with_tols.py @@ -3,7 +3,15 @@ import optimistix as optx -use_stepsize_tol = object() +class _UseStepSizeTol: + def __repr__(self): + return ( + "<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` " + "argument>" + ) + + +use_stepsize_tol = _UseStepSizeTol() def with_stepsize_controller_tols(cls: type[optx.AbstractRootFinder]): diff --git a/test/test_adjoint.py b/test/test_adjoint.py index b3ede4a6..f61a3f5e 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -363,3 +363,16 @@ def run(y0__args, adjoint): grads3 = run((y0, args), diffrax.RecursiveCheckpointAdjoint()) assert tree_allclose(grads1, grads2, rtol=1e-3, atol=1e-3) assert tree_allclose(grads1, grads3, rtol=1e-3, atol=1e-3) + + +def test_implicit_runge_kutta_direct_adjoint(): + diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: -y), + diffrax.Kvaerno5(), + 0, + 1, + 0.01, + 1.0, + adjoint=diffrax.DirectAdjoint(), + stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6), + ) diff --git a/test/test_integrate.py b/test/test_integrate.py index 85885338..05a8576b 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -504,3 +504,16 @@ def vector_field(t, y, args): assert text == "static_made_jump=False static_result=None\n" finally: diffrax._integrate._PRINT_STATIC = False + + +def test_implicit_tol_error(): + msg = "the tolerances for the implicit solver have not been specified" + with pytest.raises(ValueError, match=msg): + diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: -y), + diffrax.Kvaerno5(), + 0, + 1, + 0.01, + 1.0, + )