Skip to content

Commit

Permalink
Solver parameter specification from config (#303)
Browse files Browse the repository at this point in the history
* atol and rtol are now specified through config

* adding SOLVER_MAX_STEPS to config and fixing tests

* forgot to change value back from testing
  • Loading branch information
arik-shurygin authored Dec 9, 2024
1 parent 08582bf commit f26261c
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 3 deletions.
3 changes: 3 additions & 0 deletions config/config_inferer_covid.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@
1.0,
1.0
],
"SOLVER_RELATIVE_TOLERANCE": 1e-5,
"SOLVER_ABSOLUTE_TOLERANCE": 1e-6,
"SOLVER_MAX_STEPS": 1e6,
"SEASONALITY_AMPLITUDE": 0.15,
"SEASONALITY_SECOND_WAVE": 0.5,
"SEASONALITY_SHIFT": 0,
Expand Down
3 changes: 3 additions & 0 deletions config/config_runner_covid.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
1.0,
1.0
],
"SOLVER_RELATIVE_TOLERANCE": 1e-5,
"SOLVER_ABSOLUTE_TOLERANCE": 1e-6,
"SOLVER_MAX_STEPS": 1e6,
"SEASONALITY_AMPLITUDE": 0.15,
"SEASONALITY_SECOND_WAVE": 0.0,
"SEASONALITY_SHIFT": 20,
Expand Down
3 changes: 3 additions & 0 deletions src/dynode/abstract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class AbstractParameters:
"SEASONALITY_SHIFT",
"MIN_HOMOLOGOUS_IMMUNITY",
"WANING_RATES",
"SOLVER_RELATIVE_TOLERANCE",
"SOLVER_ABSOLUTE_TOLERANCE",
"SOLVER_MAX_STEPS",
]

@abstractmethod
Expand Down
29 changes: 29 additions & 0 deletions src/dynode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,35 @@ class is accepted to modify/create the downstream parameters.
],
"type": float,
},
{
"name": "SOLVER_RELATIVE_TOLERANCE",
"validate": [
test_not_negative,
partial(test_type, tested_type=float),
# RTOL <= 1
lambda key, val: compare_geq(["1.0", key], [1.0, val]),
],
"type": float,
},
{
"name": "SOLVER_ABSOLUTE_TOLERANCE",
"validate": [
test_not_negative,
partial(test_type, tested_type=float),
# ATOL <= 1
lambda key, val: compare_geq(["1.0", key], [1.0, val]),
],
"type": float,
},
{
"name": "SOLVER_MAX_STEPS",
"validate": [
partial(test_type, tested_type=(int)),
# STEPS >= 1
lambda key, val: compare_geq([key, "1"], [val, 1]),
],
"type": int,
},
{
"name": "STRAIN_R0s",
"validate": [
Expand Down
6 changes: 3 additions & 3 deletions src/dynode/mechanistic_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def run(
else None
)
stepsize_controller = PIDController(
rtol=1e-5,
atol=1e-6,
rtol=args.get("SOLVER_RELATIVE_TOLERANCE", 1e-5),
atol=args.get("SOLVER_ABSOLUTE_TOLERANCE", 1e-6),
jump_ts=jump_ts,
)

Expand All @@ -105,6 +105,6 @@ def run(
stepsize_controller=stepsize_controller,
saveat=saveat,
# higher for large time scales / rapid changes
max_steps=int(1e6),
max_steps=args.get("SOLVER_MAX_STEPS", int(1e6)),
)
return solution
3 changes: 3 additions & 0 deletions tests/test_config_inferer.json
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@
1.0,
1.0
],
"SOLVER_RELATIVE_TOLERANCE": 1e-5,
"SOLVER_ABSOLUTE_TOLERANCE": 1e-6,
"SOLVER_MAX_STEPS": 1e6,
"SEASONALITY_AMPLITUDE": 0.0,
"SEASONALITY_SECOND_WAVE": 0.5,
"SEASONALITY_SHIFT": 0,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_config_runner.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
1.0,
1.0
],
"SOLVER_RELATIVE_TOLERANCE": 1e-5,
"SOLVER_ABSOLUTE_TOLERANCE": 1e-6,
"SOLVER_MAX_STEPS": 1e6,
"SEASONALITY_AMPLITUDE": 0.0,
"SEASONALITY_SECOND_WAVE": 0.5,
"SEASONALITY_SHIFT": 0,
Expand Down

0 comments on commit f26261c

Please sign in to comment.