diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index beaafdeb..054111f3 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -81,4 +81,4 @@
 )
 
-__version__ = "0.0.3"
+__version__ = "0.0.4"
diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py
index 4b1b8ee1..9aad5d7d 100644
--- a/diffrax/solver/runge_kutta.py
+++ b/diffrax/solver/runge_kutta.py
@@ -81,16 +81,25 @@ def __post_init__(self):
         )
 
 
-_SolverState = Optional[Tuple[Optional[PyTree], Scalar]]
+_SolverState = Optional[PyTree]
 
 
 # TODO: examine termination criterion for Newton iteration
-# TODO: consider dividing by diagonal and control
-# TODO: replace ki with ki=(zi + predictor), where this relation defines some zi, and
+# TODO: replace fi with fi=(zi + predictor), where this relation defines some zi, and
 #       iterate to find zi, using zi=0 as the predictor. This should give better
 #       numerical behaviour since the iteration is close to 0. (Although we have
 #       multiplied by the increment of the control, i.e. dt, which is small...)
-def _implicit_relation(ki, nonlinear_solve_args):
+def _implicit_relation_f(fi, nonlinear_solve_args):
+    diagonal, vf, prod, ti, yi_partial, args, control = nonlinear_solve_args
+    diff = (
+        vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω
+        - fi**ω
+    ).ω
+    return diff
+
+
+# TODO: consider dividing by diagonal and control
+def _implicit_relation_k(ki, nonlinear_solve_args):
     # c.f:
     # https://github.com/SciML/DiffEqDevMaterials/blob/master/newton/output/main.pdf
     # (Bearing in mind that our ki is dt times smaller than theirs.)
@@ -113,8 +122,20 @@ def tableau(self) -> ButcherTableau:
 
     @abc.abstractmethod
     def _recompute_jac(self, i: int) -> bool:
+        """Called on the i'th stage for all i. Used to determine when the Jacobian
+        should be recomputed or not.
+        """
         pass
 
+    def func_for_init(
+        self,
+        terms: AbstractTerm,
+        t0: Scalar,
+        y0: PyTree,
+        args: PyTree,
+    ) -> PyTree:
+        return terms.func_for_init(t0, y0, args)
+
     def init(
         self,
         terms: AbstractTerm,
@@ -123,11 +144,10 @@ def init(
         y0: PyTree,
         args: PyTree,
     ) -> _SolverState:
-        if self.tableau.fsal:
-            control = terms.contr(t0, t1)
-            k0 = terms.vf_prod(t0, y0, args, control)
-            dt = t1 - t0
-            return k0, dt
+        vf_expensive = terms.is_vf_expensive(t0, t1, y0, args)
+        fsal = self.tableau.fsal and not vf_expensive
+        if fsal:
+            return terms.vf(t0, y0, args)
         else:
             return None
 
@@ -142,18 +162,75 @@ def step(
         made_jump: Bool,
     ) -> Tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]:
+        #
+        # Some Runge--Kutta methods have special structure that we can use to improve
+        # efficiency.
+        #
+        # The famous one is FSAL; "first same as last". That is, the final evaluation
+        # of the vector field on the previous step is the same as the first evaluation
+        # on the subsequent step. We can reuse it and save an evaluation.
+        # However note that this requires saving a vf evaluation, not a
+        # vf-control-product. (This comes up when we have a different control on the
+        # next step, e.g. as with adaptive step sizes, or with SDEs.)
+        # As such we disable this if a vf is expensive and a vf-control-product is
+        # cheap. (The canonical example is the optimise-then-discretise adjoint SDE.
+        # For this SDE, the vf-control product is a vector-Jacobian product, which is
+        # notably cheaper than evaluating a full Jacobian.)
+        #
+        # Next we have SSAL; "solution same as last". That is, the output of the step
+        # has already been calculated during the internal stage calculations. We can
+        # reuse those and save a dot product.
+        #
+        # Finally we have a choice whether to save and work with vector field
+        # evaluations (fs), or to save and work with (vector field)-control products
+        # (ks).
+        # The former is needed for implicit FSAL solvers: they need to obtain the
+        # final f1 for the FSAL property, which means they need to do the implicit
+        # solve in vf-space rather than (vf-control-product)-space, which means they
+        # need to use `fs` to predict the initial point for the root finding operation.
+        # Meanwhile the latter is needed when solving optimise-then-discretise adjoint
+        # SDEs, for which vector field evaluations are prohibitively expensive, and we
+        # must necessarily work only with the (much cheaper) vf-control-products. (In
+        # this case this is the difference between computing a Jacobian and computing a
+        # vector-Jacobian product.)
+        # For other problems, we choose to use `ks`. This doesn't have a strong
+        # rationale although it does have some minor efficiency points in its favour,
+        # e.g. we need `ks` to perform dense interpolation if needed.
+        #
+        _vf_expensive = terms.is_vf_expensive(t0, t1, y0, args)
+        _implicit_later_stages = self.tableau.a_diagonal is not None and any(
+            self.tableau.a_diagonal[1:] != 0
+        )
+        fsal = self.tableau.fsal and not _vf_expensive
+        ssal = self.tableau.ssal
+        if _implicit_later_stages and fsal:
+            use_fs = True
+        elif _vf_expensive:
+            use_fs = False
+        else:  # Choice not as important here; we use ks for minor efficiency reasons.
+            use_fs = False
+
+        #
+        # Initialise values. Evaluate the first stage if not FSAL.
+        #
+
         control = terms.contr(t0, t1)
         dt = t1 - t0
 
-        if self.tableau.fsal:
-            k0, prev_dt = solver_state
-            k0 = lax.cond(
-                made_jump,
-                lambda _: terms.vf_prod(t0, y0, args, control),
-                lambda _: (k0**ω * (dt / prev_dt)).ω,
-                None,
-            )
-            jac = None
+        if fsal:
+            if use_fs:
+                f0 = solver_state
+            else:
+                f0 = solver_state
+                k0 = lax.cond(
+                    made_jump,
+                    lambda _: terms.vf_prod(t0, y0, args, control),
+                    lambda _: terms.prod(f0, control),
+                    None,
+                )
+                del f0
+            jac_f = None
+            jac_k = None
             result = RESULTS.successful
         else:
             if self.tableau.a_diagonal is None:
@@ -164,16 +241,44 @@ def step(
                 t0_ = t1
             else:
                 t0_ = t0 + self.tableau.diagonal[0] * dt
-            k0, jac, result = self._eval_stage(
-                terms, 0, t0_, y0, args, control, jac=None, k=None
+            f0, k0, jac_f, jac_k, result = self._eval_stage(
+                terms,
+                0,
+                t0_,
+                y0,
+                args,
+                control,
+                jac_f=None,
+                jac_k=None,
+                fs=None,
+                ks=None,
+                return_fi=use_fs,
+                return_ki=not use_fs,
             )
+            if use_fs:
+                assert k0 is None
+                del k0
+            else:
+                assert f0 is None
+                del f0
+
+        #
+        # Initialise `fs` or `ks` as a place to store the stage evaluations.
+        #
 
-        # Note that our `k` is (for an ODE) `dt` times smaller than the usual
-        # implementation (e.g. what you see in torchdiffeq or in the reference texts).
-        # This is because of our vector-field-control approach.
         lentime = (len(self.tableau.c) + 1,)
-        k = jax.tree_map(lambda y: jnp.empty(lentime + jnp.shape(y)), y0)
-        k = (k**ω).at[0].set(k0**ω).ω
+        if use_fs:
+            fs = jax.tree_map(lambda f: jnp.empty(lentime + jnp.shape(f)), f0)
+            fs = (fs**ω).at[0].set(f0**ω).ω
+            ks = None
+        else:
+            fs = None
+            ks = jax.tree_map(lambda k: jnp.empty(lentime + jnp.shape(k)), k0)
+            ks = (ks**ω).at[0].set(k0**ω).ω
+
+        #
+        # Iterate through the stages
+        #
 
         for i, (a_i, c_i) in enumerate(zip(self.tableau.a_lower, self.tableau.c)):
             if c_i == 1:
@@ -181,75 +286,185 @@ def step(
                 ti = t1
             else:
                 ti = t0 + c_i * dt
-            yi_partial = (y0**ω + vector_tree_dot(a_i, ω(k)[: i + 1].ω) ** ω).ω
-            ki, jac, new_result = self._eval_stage(
-                terms, i + 1, ti, yi_partial, args, control, jac, k
+            if use_fs:
+                increment = vector_tree_dot(a_i, ω(fs)[: i + 1].ω)
+                increment = terms.prod(increment, control)
+            else:
+                increment = vector_tree_dot(a_i, ω(ks)[: i + 1].ω)
+            yi_partial = (y0**ω + increment**ω).ω
+            last_iteration = i == len(self.tableau.a_lower) - 1
+            return_fi = use_fs or (fsal and last_iteration)
+            return_ki = not use_fs
+            fi, ki, jac_f, jac_k, new_result = self._eval_stage(
+                terms,
+                i + 1,
+                ti,
+                yi_partial,
+                args,
+                control,
+                jac_f,
+                jac_k,
+                fs,
+                ks,
+                return_fi,
+                return_ki,
             )
+            if not return_fi:
+                assert fi is None
+                del fi
+            if use_fs:
+                assert ki is None
+                del ki
             result = jnp.where(result == RESULTS.successful, new_result, result)
-            # TODO: fast path to skip the rest of the stages if result is not successful
-            k = ω(k).at[i + 1].set(ω(ki)).ω
+            if use_fs:
+                fs = ω(fs).at[i + 1].set(ω(fi)).ω
+            else:
+                ks = ω(ks).at[i + 1].set(ω(ki)).ω
+
+        #
+        # Compute step output
+        #
 
-        if self.tableau.ssal:
+        if ssal:
             y1 = yi_partial
         else:
-            y1 = (y0**ω + vector_tree_dot(self.tableau.b_sol, k) ** ω).ω
-        if self.tableau.fsal:
-            k1 = (k**ω)[-1].ω
+            if use_fs:
+                increment = vector_tree_dot(self.tableau.b_sol, fs)
+                increment = terms.prod(increment, control)
+            else:
+                increment = vector_tree_dot(self.tableau.b_sol, ks)
+            y1 = (y0**ω + increment**ω).ω
+
+        #
+        # Compute error estimate
+        #
+
+        if use_fs:
+            y_error = vector_tree_dot(self.tableau.b_error, fs)
+            y_error = terms.prod(y_error, control)
         else:
-            k1 = None
-        y_error = vector_tree_dot(self.tableau.b_error, k)
+            y_error = vector_tree_dot(self.tableau.b_error, ks)
         y_error = jax.tree_map(
             lambda _y_error: jnp.where(result == RESULTS.successful, _y_error, jnp.inf),
             y_error,
-        )
-        dense_info = dict(y0=y0, y1=y1, k=k)
-        if self.tableau.fsal:
-            solver_state = (k1, dt)
+        )  # i.e. an implicit step failed to converge
+
+        #
+        # Compute dense info
+        #
+
+        if use_fs:
+            if fs is None:
+                # Edge case for diffeqsolve(y0=None)
+                ks = None
+            else:
+                ks = jax.vmap(lambda f: terms.prod(f, control))(fs)
+        dense_info = dict(y0=y0, y1=y1, k=ks)
+
+        #
+        # Compute next solver state
+        #
+
+        if fsal:
+            solver_state = fi
         else:
             solver_state = None
+
         return y1, y_error, dense_info, solver_state, result
 
-    def func_for_init(
+    def _eval_stage(
         self,
-        terms: AbstractTerm,
-        t0: Scalar,
-        y0: PyTree,
-        args: PyTree,
-    ) -> PyTree:
-        return terms.func_for_init(t0, y0, args)
-
-    def _eval_stage(self, terms, i, ti, yi_partial, args, control, jac, k):
+        terms,
+        i,
+        ti,
+        yi_partial,
+        args,
+        control,
+        jac_f,
+        jac_k,
+        fs,
+        ks,
+        return_fi,
+        return_ki,
+    ):
+        assert return_fi or return_ki
         if self.tableau.a_diagonal is None:
             diagonal = 0
         else:
             diagonal = self.tableau.a_diagonal[i]
         if diagonal == 0:
             # Explicit stage
-            ki = terms.vf_prod(ti, yi_partial, args, control)
-            return ki, jac, RESULTS.successful
+            if return_fi:
+                fi = terms.vf(ti, yi_partial, args)
+                if return_ki:
+                    ki = terms.prod(fi, control)
+                else:
+                    ki = None
+            else:
+                fi = None
+                if return_ki:
+                    ki = terms.vf_prod(ti, yi_partial, args, control)
+                else:
+                    assert False
+            return fi, ki, jac_f, jac_k, RESULTS.successful
         else:
             # Implicit stage
-            if i == 0:
-                # Implicit first stage. Make an extra function evaluation to use as a
-                # predictor for the solution to the first stage.
-                ki_pred = terms.vf_prod(ti, yi_partial, args, control)
-            else:
-                ki_pred = vector_tree_dot(self.tableau.a_predictor[i - 1], ω(k)[:i].ω)
-            if self._recompute_jac(i):
-                jac = self.nonlinear_solver.jac(
-                    _implicit_relation,
-                    ki_pred,
-                    (diagonal, terms.vf_prod, ti, yi_partial, args, control),
+            if return_fi:
+                if i == 0:
+                    # Implicit first stage. Make an extra function evaluation to use as
+                    # a predictor for the solution to the first stage.
+                    fi_pred = terms.vf(ti, yi_partial, args)
+                else:
+                    fi_pred = vector_tree_dot(
+                        self.tableau.a_predictor[i - 1], ω(fs)[:i].ω
+                    )
+                if self._recompute_jac(i):
+                    jac_f = self.nonlinear_solver.jac(
+                        _implicit_relation_f,
+                        fi_pred,
+                        (diagonal, terms.vf, terms.prod, ti, yi_partial, args, control),
+                    )
+                assert jac_f is not None
+                nonlinear_sol = self.nonlinear_solver(
+                    _implicit_relation_f,
+                    fi_pred,
+                    (diagonal, terms.vf, terms.prod, ti, yi_partial, args, control),
+                    jac_f,
                 )
-            assert jac is not None
-            nonlinear_sol = self.nonlinear_solver(
-                _implicit_relation,
-                ki_pred,
-                (diagonal, terms.vf_prod, ti, yi_partial, args, control),
-                jac,
-            )
-            ki = nonlinear_sol.root
-            return ki, jac, nonlinear_sol.result
+                fi = nonlinear_sol.root
+                if return_ki:
+                    ki = terms.prod(fi, control)
+                else:
+                    ki = None
+                return fi, ki, jac_f, jac_k, nonlinear_sol.result
+            else:
+                if return_ki:
+                    if i == 0:
+                        # Implicit first stage. Make an extra function evaluation to
+                        # use as a predictor for the solution to the first stage.
+                        ki_pred = terms.vf_prod(ti, yi_partial, args, control)
+                    else:
+                        ki_pred = vector_tree_dot(
+                            self.tableau.a_predictor[i - 1], ω(ks)[:i].ω
+                        )
+                    if self._recompute_jac(i):
+                        jac_k = self.nonlinear_solver.jac(
+                            _implicit_relation_k,
+                            ki_pred,
+                            (diagonal, terms.vf_prod, ti, yi_partial, args, control),
+                        )
+                    assert jac_k is not None
+                    nonlinear_sol = self.nonlinear_solver(
+                        _implicit_relation_k,
+                        ki_pred,
+                        (diagonal, terms.vf_prod, ti, yi_partial, args, control),
+                        jac_k,
+                    )
+                    fi = None
+                    ki = nonlinear_sol.root
+                    return fi, ki, jac_f, jac_k, nonlinear_sol.result
+                else:
+                    assert False
 
 
 class AbstractERK(AbstractRungeKutta):
diff --git a/diffrax/term.py b/diffrax/term.py
index ec8fe571..d9281dfb 100644
--- a/diffrax/term.py
+++ b/diffrax/term.py
@@ -139,7 +139,7 @@ def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree
     # This is a pinhole break in our vector-field/control abstraction.
     # Everywhere else we get to evaluate over some interval, which allows us to
     # evaluate our control over that interval. However to select the initial point in
-    # an adapative step size scheme, the standard heuristic is to start by making
+    # an adaptive step size scheme, the standard heuristic is to start by making
     # evaluations at just the initial point -- no intervals involved.
     def func_for_init(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
         """This is a special-cased version of [`diffrax.AbstractTerm.vf`][].
@@ -153,6 +153,7 @@ def func_for_init(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
 
         See [`diffrax.AbstractSolver.func_for_init`][].
         """
+        # Heuristic for whether it's safe to select an initial step automatically.
         vf = self.vf(t, y, args)
         flat_vf, tree_vf = jax.tree_flatten(vf)
         flat_y, tree_y = jax.tree_flatten(y)
@@ -168,6 +169,20 @@
         else:
             return vf
 
+    def is_vf_expensive(
+        self,
+        t0: Scalar,
+        t1: Scalar,
+        y: Tuple[PyTree, PyTree, PyTree, PyTree],
+        args: PyTree,
+    ) -> bool:
+        """Specifies whether evaluating the vector field is "expensive", in the
+        specific sense that it is cheaper to evaluate `vf_prod` twice than `vf` once.
+
+        Some solvers use this to change their behaviour, so as to act more efficiently.
+        """
+        return False
+
 
 class ODETerm(AbstractTerm):
     r"""A term representing $f(t, y(t), args) \mathrm{d}t$. That is to say, the term
@@ -392,11 +407,24 @@ def prod(self, vf: PyTree, control: PyTree) -> PyTree:
 class AdjointTerm(AbstractTerm):
     term: AbstractTerm
 
+    def is_vf_expensive(
+        self,
+        t0: Scalar,
+        t1: Scalar,
+        y: Tuple[PyTree, PyTree, PyTree, PyTree],
+        args: PyTree,
+    ) -> bool:
+        control = self.contr(t0, t1)
+        if sum(c.size for c in jax.tree_leaves(control)) in (0, 1):
+            return False
+        else:
+            return True
+
     def vf(
         self, t: Scalar, y: Tuple[PyTree, PyTree, PyTree, PyTree], args: PyTree
     ) -> PyTree:
         # We compute the vector field via `self.vf_prod`. We could also do it manually,
-        # but this is relatively painless.#
+        # but this is relatively painless.
         #
         # This can be done because `self.vf_prod` is linear in `control`. As such we
         # can obtain just the vector field component by representing this linear
@@ -443,7 +471,7 @@ def _fn(_control):
         if jax.tree_structure(None) in (vf_prod_tree, control_tree):
             # An unusual/not-useful edge case to handle.
             raise NotImplementedError(
-                "`AdjointTerm` not implemented for `None` controls or states."
+                "`AdjointTerm.vf` not implemented for `None` controls or states."
             )
         return jax.tree_transpose(vf_prod_tree, control_tree, jac)
diff --git a/docs/api/terms.md b/docs/api/terms.md
index 4418ac8c..73316cbb 100644
--- a/docs/api/terms.md
+++ b/docs/api/terms.md
@@ -35,6 +35,7 @@ The very first argument to [`diffrax.diffeqsolve`][] should be some PyTree of te
         - contr
         - prod
         - vf_prod
+        - is_vf_expensive
 
 ---
 
diff --git a/test/helpers.py b/test/helpers.py
index 119cb5e5..a060c07d 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -43,8 +43,8 @@ def random_pytree(key, treedef):
 treedefs = [
     jax.tree_structure(x)
     for x in (
-        None,
         0,
+        None,
         {"a": [0, 0], "b": 0},
     )
 ]
diff --git a/test/test_integrate.py b/test/test_integrate.py
index abb368b5..ca589a9a 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -13,22 +13,42 @@
 from helpers import all_ode_solvers, random_pytree, shaped_allclose, treedefs
 
 
+def _all_pairs(*args):
+    defaults = [arg["default"] for arg in args]
+    yield defaults
+    for i in range(len(args)):
+        for opt in args[i]["opts"]:
+            opts = defaults.copy()
+            opts[i] = opt
+            yield opts
+    for i in range(len(args)):
+        for j in range(i + 1, len(args)):
+            for opt1 in args[i]["opts"]:
+                for opt2 in args[j]["opts"]:
+                    opts = defaults.copy()
+                    opts[i] = opt1
+                    opts[j] = opt2
+                    yield opts
+
+
 @pytest.mark.parametrize(
-    "solver_ctr",
-    (
-        diffrax.Euler,
-        diffrax.LeapfrogMidpoint,
-        diffrax.ReversibleHeun,
-        diffrax.Tsit5,
-        diffrax.ImplicitEuler,
-        diffrax.Kvaerno3,
+    "solver_ctr,t_dtype,treedef,stepsize_controller",
+    _all_pairs(
+        dict(
+            default=diffrax.Euler,
+            opts=(
+                diffrax.LeapfrogMidpoint,
+                diffrax.ReversibleHeun,
+                diffrax.Tsit5,
+                diffrax.ImplicitEuler,
+                diffrax.Kvaerno3,
+            ),
+        ),
+        dict(default=jnp.float32, opts=(int, float, jnp.int32)),
+        dict(default=treedefs[0], opts=treedefs[1:]),
+        dict(default=diffrax.ConstantStepSize(), opts=(diffrax.PIDController(),)),
     ),
 )
-@pytest.mark.parametrize("t_dtype", (int, float, jnp.int32, jnp.float32))
-@pytest.mark.parametrize("treedef", treedefs)
-@pytest.mark.parametrize(
-    "stepsize_controller", (diffrax.ConstantStepSize(), diffrax.PIDController(atol=1e2))
-)
 def test_basic(solver_ctr, t_dtype, treedef, stepsize_controller, getkey):
     if not issubclass(solver_ctr, diffrax.AbstractAdaptiveSolver) and isinstance(
         stepsize_controller, diffrax.PIDController
@@ -40,25 +60,25 @@ def f(t, y, args):
 
     if t_dtype is int:
         t0 = 0
-        t1 = 2
-        dt0 = 1
+        t1 = 1
+        dt0 = 0.01
     elif t_dtype is float:
         t0 = 0.0
-        t1 = 2.0
-        dt0 = 1.0
+        t1 = 1.0
+        dt0 = 0.01
     elif t_dtype is jnp.int32:
         t0 = jnp.array(0)
-        t1 = jnp.array(2)
-        dt0 = jnp.array(1)
+        t1 = jnp.array(1)
+        dt0 = jnp.array(0.01)
     elif t_dtype is jnp.float32:
         t0 = jnp.array(0.0)
-        t1 = jnp.array(2.0)
-        dt0 = jnp.array(1.0)
+        t1 = jnp.array(1.0)
+        dt0 = jnp.array(0.01)
     else:
         raise ValueError
     y0 = random_pytree(getkey(), treedef)
     try:
-        diffrax.diffeqsolve(
+        sol = diffrax.diffeqsolve(
             diffrax.ODETerm(f),
             solver_ctr(),
             t0,
@@ -76,6 +96,9 @@ def f(t, y, args):
             pass
         else:
             raise
+    y1 = sol.ys
+    true_y1 = jax.tree_map(lambda x: (x * math.exp(-1))[None], y0)
+    assert shaped_allclose(y1, true_y1, atol=1e-2, rtol=1e-2)
 
 
 @pytest.mark.parametrize("solver_ctr", all_ode_solvers)
@@ -278,32 +301,6 @@ def f(t, y, args):
     assert shaped_allclose(sol1.derivative(ti), -sol2.derivative(-ti))
 
 
-@pytest.mark.parametrize(
-    "solver_ctr,stepsize_controller,dt0",
-    (
-        (diffrax.Tsit5, diffrax.ConstantStepSize(), 0.3),
-        (diffrax.Tsit5, diffrax.PIDController(rtol=1e-8, atol=1e-8), None),
-        (diffrax.Kvaerno3, diffrax.PIDController(rtol=1e-8, atol=1e-8), None),
-    ),
-)
-@pytest.mark.parametrize("treedef", treedefs)
-def test_pytree_state(solver_ctr, stepsize_controller, dt0, treedef, getkey):
-    term = diffrax.ODETerm(lambda t, y, args: jax.tree_map(operator.neg, y))
-    y0 = random_pytree(getkey(), treedef)
-    sol = diffrax.diffeqsolve(
-        term,
-        solver=solver_ctr(),
-        t0=0,
-        t1=1,
-        dt0=dt0,
-        y0=y0,
-        stepsize_controller=stepsize_controller,
-    )
-    y1 = sol.ys
-    true_y1 = jax.tree_map(lambda x: (x * math.exp(-1))[None], y0)
-    assert shaped_allclose(y1, true_y1)
-
-
 def test_semi_implicit_euler():
     term1 = diffrax.ODETerm(lambda t, y, args: -y)
     term2 = diffrax.ODETerm(lambda t, y, args: y)
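
---

A note on `_implicit_relation_f` (diffrax/solver/runge_kutta.py above): stripped of the
PyTree/ω machinery, and specialised to an ODE where `prod(f, control)` is just `f * dt`,
the stage equation being solved is `fi = vf(ti, yi_partial + diagonal * fi * dt, args)`.
The sketch below is standalone and illustrative only, not diffrax's API: `newton_solve`
is a hypothetical scalar-only stand-in for the configurable `self.nonlinear_solver`.

```python
import jax

def implicit_relation_f(fi, diagonal, vf, ti, yi_partial, args, dt):
    # Residual of the diagonally-implicit stage equation
    #     fi = vf(ti, yi_partial + diagonal * fi * dt, args);
    # the nonlinear solver seeks the fi at which this residual is zero.
    return vf(ti, yi_partial + diagonal * fi * dt, args) - fi

def newton_solve(fi_pred, diagonal, vf, ti, yi_partial, args, dt, iters=5):
    # Plain Newton iteration (scalar y only, for illustration), started from the
    # predictor fi_pred -- just as the solver above starts from fi_pred/ki_pred.
    g = lambda fi: implicit_relation_f(fi, diagonal, vf, ti, yi_partial, args, dt)
    fi = fi_pred
    for _ in range(iters):
        fi = fi - g(fi) / jax.grad(g)(fi)
    return fi

# e.g. one implicit Euler stage (diagonal=1) on dy/dt = -10*y, with dt=0.1:
vf = lambda t, y, args: -10.0 * y
fi = newton_solve(vf(0.0, 1.0, None), 1.0, vf, 0.1, 1.0, None, 0.1)  # -> -5.0
```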
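
On the FSAL discussion in the `step` comments (runge_kutta.py above): a minimal
standalone sketch of the pattern on plain arrays, using the Bogacki--Shampine 3(2)
pair, an FSAL method. This is illustrative, not diffrax's implementation. The point is
that the cached quantity is the vector field evaluation `f1`, not the product
`f1 * dt`: under adaptive step sizing the next step's dt differs, so the product is
re-formed against the new control.

```python
def bosh3_step(vf, t0, dt, y0, f0):
    # `f0` is vf(t0, y0). On every step after the first it is the `f1` returned
    # by the previous call ("first same as last"), saving one vf evaluation.
    k0 = f0 * dt  # the product is re-formed with *this* step's dt
    k1 = vf(t0 + 0.5 * dt, y0 + 0.5 * k0) * dt
    k2 = vf(t0 + 0.75 * dt, y0 + 0.75 * k1) * dt
    y1 = y0 + (2 / 9) * k0 + (1 / 3) * k1 + (4 / 9) * k2
    f1 = vf(t0 + dt, y1)  # final evaluation; becomes the next step's f0
    k3 = f1 * dt
    y1_lower = y0 + (7 / 24) * k0 + (1 / 4) * k1 + (1 / 3) * k2 + (1 / 8) * k3
    return y1, y1 - y1_lower, f1  # solution, local error estimate, FSAL value
```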
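
On `AdjointTerm.is_vf_expensive` (diffrax/term.py above): for the adjoint term,
`vf_prod` is a single vector-Jacobian product while `vf` materialises a full Jacobian,
roughly one VJP per element of the control, which is why it returns True whenever the
control has more than one element. A standalone sketch of that cost asymmetry (the
vector field `f` here is made up for illustration):

```python
import jax
import jax.numpy as jnp

def f(y):  # stand-in vector field, R^n -> R^n
    return jnp.tanh(y)

y = jnp.arange(3.0)
a = jnp.ones(3)  # e.g. an adjoint/cotangent variable

# vf_prod-style: one backward pass, cost comparable to one evaluation of f.
_, vjp_fn = jax.vjp(f, y)
(aT_jac,) = vjp_fn(a)

# vf-style: the full Jacobian, i.e. n backward passes under the hood.
full_jac = jax.jacrev(f)(y)
assert jnp.allclose(a @ full_jac, aT_jac)
```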
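
On the new `_all_pairs` helper (test/test_integrate.py above): rather than the full
Cartesian product that the removed stacked `@pytest.mark.parametrize` decorators
produced, it tests the all-defaults combination, every single-parameter deviation, and
every two-parameter deviation. A toy run (standalone copy of the helper; the inputs
are made up):

```python
def _all_pairs(*args):
    # All defaults first...
    defaults = [arg["default"] for arg in args]
    yield defaults
    # ...then every single-parameter deviation...
    for i in range(len(args)):
        for opt in args[i]["opts"]:
            opts = defaults.copy()
            opts[i] = opt
            yield opts
    # ...then every two-parameter deviation.
    for i in range(len(args)):
        for j in range(i + 1, len(args)):
            for opt1 in args[i]["opts"]:
                for opt2 in args[j]["opts"]:
                    opts = defaults.copy()
                    opts[i] = opt1
                    opts[j] = opt2
                    yield opts

combos = list(_all_pairs(dict(default=0, opts=(1, 2)), dict(default="a", opts=("b",))))
# [[0, 'a'], [1, 'a'], [2, 'a'], [0, 'b'], [1, 'b'], [2, 'b']]
```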