From f26261c5ea2170741180a83c3bb03e3d0b62b927 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin <39861882+arik-shurygin@users.noreply.github.com> Date: Mon, 9 Dec 2024 12:55:08 -0800 Subject: [PATCH] Solver parameter specification from config (#303) * atol and rtol are now specified through config * adding SOLVER_MAX_STEPS to config and fixing tests * forgot to change value back from testing --- config/config_inferer_covid.json | 3 +++ config/config_runner_covid.json | 3 +++ src/dynode/abstract_parameters.py | 3 +++ src/dynode/config.py | 29 +++++++++++++++++++++++++++++ src/dynode/mechanistic_runner.py | 6 +++--- tests/test_config_inferer.json | 3 +++ tests/test_config_runner.json | 3 +++ 7 files changed, 47 insertions(+), 3 deletions(-) diff --git a/config/config_inferer_covid.json b/config/config_inferer_covid.json index 7aa01c4f..03f9bfff 100644 --- a/config/config_inferer_covid.json +++ b/config/config_inferer_covid.json @@ -116,6 +116,9 @@ 1.0, 1.0 ], + "SOLVER_RELATIVE_TOLERANCE": 1e-5, + "SOLVER_ABSOLUTE_TOLERANCE": 1e-6, + "SOLVER_MAX_STEPS": 1e6, "SEASONALITY_AMPLITUDE": 0.15, "SEASONALITY_SECOND_WAVE": 0.5, "SEASONALITY_SHIFT": 0, diff --git a/config/config_runner_covid.json b/config/config_runner_covid.json index b410e477..345aebf5 100644 --- a/config/config_runner_covid.json +++ b/config/config_runner_covid.json @@ -80,6 +80,9 @@ 1.0, 1.0 ], + "SOLVER_RELATIVE_TOLERANCE": 1e-5, + "SOLVER_ABSOLUTE_TOLERANCE": 1e-6, + "SOLVER_MAX_STEPS": 1e6, "SEASONALITY_AMPLITUDE": 0.15, "SEASONALITY_SECOND_WAVE": 0.0, "SEASONALITY_SHIFT": 20, diff --git a/src/dynode/abstract_parameters.py b/src/dynode/abstract_parameters.py index 488d8bf2..4744007f 100644 --- a/src/dynode/abstract_parameters.py +++ b/src/dynode/abstract_parameters.py @@ -55,6 +55,9 @@ class AbstractParameters: "SEASONALITY_SHIFT", "MIN_HOMOLOGOUS_IMMUNITY", "WANING_RATES", + "SOLVER_RELATIVE_TOLERANCE", + "SOLVER_ABSOLUTE_TOLERANCE", + "SOLVER_MAX_STEPS", ] @abstractmethod diff --git a/src/dynode/config.py b/src/dynode/config.py index 402dbe65..cc2a8f67 100644 --- a/src/dynode/config.py +++ b/src/dynode/config.py @@ -741,6 +741,35 @@ class is accepted to modify/create the downstream parameters. ], "type": float, }, + { + "name": "SOLVER_RELATIVE_TOLERANCE", + "validate": [ + test_not_negative, + partial(test_type, tested_type=float), + # RTOL <= 1 + lambda key, val: compare_geq(["1.0", key], [1.0, val]), + ], + "type": float, + }, + { + "name": "SOLVER_ABSOLUTE_TOLERANCE", + "validate": [ + test_not_negative, + partial(test_type, tested_type=float), + # ATOL <= 1 + lambda key, val: compare_geq(["1.0", key], [1.0, val]), + ], + "type": float, + }, + { + "name": "SOLVER_MAX_STEPS", + "validate": [ + partial(test_type, tested_type=(int)), + # STEPS >= 1 + lambda key, val: compare_geq([key, "1"], [val, 1]), + ], + "type": int, + }, { "name": "STRAIN_R0s", "validate": [ diff --git a/src/dynode/mechanistic_runner.py b/src/dynode/mechanistic_runner.py index 55f8edbb..47fd3815 100644 --- a/src/dynode/mechanistic_runner.py +++ b/src/dynode/mechanistic_runner.py @@ -89,8 +89,8 @@ def run( else None ) stepsize_controller = PIDController( - rtol=1e-5, - atol=1e-6, + rtol=args.get("SOLVER_RELATIVE_TOLERANCE", 1e-5), + atol=args.get("SOLVER_ABSOLUTE_TOLERANCE", 1e-6), jump_ts=jump_ts, ) @@ -105,6 +105,6 @@ def run( stepsize_controller=stepsize_controller, saveat=saveat, # higher for large time scales / rapid changes - max_steps=int(1e6), + max_steps=args.get("SOLVER_MAX_STEPS", int(1e6)), ) return solution diff --git a/tests/test_config_inferer.json b/tests/test_config_inferer.json index 2bed8602..9e2aa98a 100644 --- a/tests/test_config_inferer.json +++ b/tests/test_config_inferer.json @@ -101,6 +101,9 @@ 1.0, 1.0 ], + "SOLVER_RELATIVE_TOLERANCE": 1e-5, + "SOLVER_ABSOLUTE_TOLERANCE": 1e-6, + "SOLVER_MAX_STEPS": 1e6, "SEASONALITY_AMPLITUDE": 0.0, "SEASONALITY_SECOND_WAVE": 0.5, "SEASONALITY_SHIFT": 0, diff --git a/tests/test_config_runner.json b/tests/test_config_runner.json index 7c7ae13f..8ec69d86 100644 --- a/tests/test_config_runner.json +++ b/tests/test_config_runner.json @@ -80,6 +80,9 @@ 1.0, 1.0 ], + "SOLVER_RELATIVE_TOLERANCE": 1e-5, + "SOLVER_ABSOLUTE_TOLERANCE": 1e-6, + "SOLVER_MAX_STEPS": 1e6, "SEASONALITY_AMPLITUDE": 0.0, "SEASONALITY_SECOND_WAVE": 0.5, "SEASONALITY_SHIFT": 0,