Skip to content

Commit

Permalink
Fix mypy errors (#306)
Browse files Browse the repository at this point in the history
* Initial mypy.ini

* Fix abstract_parameters.py

* Fix abstract_parameters.py

* Fix solution_interpreter

* Fixed __init__.py

* Fix mechanistic_inferer

* Suppress error in example_end_to_end_run.py

* Fix example_end_to_end_run.py

* Fix test_inferer.py

* fix test_utils.py

* Fix dynode_runner

* Fix covid_sero_initializer

* Fix sim_data_to_sero_generator

* checkpoint

* Checkpoint

* Fix vis_utils.py

* Fix seip_model

* First clean mypy run

* Fix indexes in dynode_utils

* Fix indexes in dynode_utils

* ignore error

* Pre-commit

* Add CI for mypy (#308)

* fixing up type hints in comments and fixing some ignored errors

* fixing mypy hint for identify_distribution_indexes

* fixing one lingering mypy error

* fixing mypy error with new end to end run

---------

Co-authored-by: EKR <ekr@rtfm.com>
Co-authored-by: Ariel Shurygin <arik.shurygin@gmail.com>
  • Loading branch information
3 people authored Jan 7, 2025
1 parent 55da1a1 commit 3690865
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 82 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Static type-checking workflow: runs mypy over the whole repository
# on every pull request and on pushes to main.
name: Run MyPy

on:
  pull_request:
  push:
    branches: [main]

jobs:
  # The job id is what the checks UI reports, so it is named after the
  # tool this job actually runs.
  mypy:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      # Poetry is installed globally; it then creates the project environment.
      - name: Install global dependencies
        run: |
          pip install poetry
      - name: Set up
        run: |
          poetry install
      # mypy picks up its settings from the repository's .mypy.ini.
      - name: Run myPy
        run: |
          poetry run mypy .
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[mypy]
# Globally silence the `attr-defined` and `import-untyped` error codes.
# NOTE(review): presumably suppressed to reach a first clean run - confirm
# whether these can be re-enabled (or narrowed to per-module overrides) later.
disable_error_code = attr-defined,import-untyped
# Load NumPy's mypy plugin.
plugins = numpy.typing.mypy_plugin
22 changes: 12 additions & 10 deletions data_manipulation_scripts/sim_data_to_sero_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,18 @@
how="left",
).drop("p_id", axis=1)

infection_history_by_infected.columns = [
"pid",
"age",
"strains",
"last_infected_date",
"last_infectious_start_date",
"last_infectious_end_date",
"num_doses",
"last_vax_date",
]
infection_history_by_infected.columns = pd.Index(
[
"pid",
"age",
"strains",
"last_infected_date",
"last_infectious_start_date",
"last_infectious_end_date",
"num_doses",
"last_vax_date",
]
)

### FILL NA VALUES WITH GOOD DEFAULTS AND SET TO INT FROM FLOAT.
infection_history_by_infected["strains"] = (
Expand Down
10 changes: 6 additions & 4 deletions examples/example_end_to_end_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import argparse
import os

import jax.numpy as jnp
import numpy as np

# the different segments of code responsible for running the model
Expand All @@ -38,19 +39,20 @@


class ExampleDynodeRunner(AbstractDynodeRunner):
def process_state(self, _: str, infer: bool = False):
def process_state(self, state: str, **kwargs):
"""An example of a method used to process a single state using DynODE
In this case the example configs are built around US data, so we are
running the whole US as a single entity.
Parameters
----------
_ : str
state : str
state USPS, ignored in this specific example
infer : bool
whether or not the user of this example script wants to run inference;
required keyword argument, read from kwargs["infer"] (no default)
"""
infer = bool(kwargs["infer"])
# step 1: define your paths
config_path = "examples/config/"
# global_config include definitions such as age bin bounds and strain definitions
Expand Down Expand Up @@ -111,7 +113,7 @@ def process_state(self, _: str, infer: bool = False):
print("Fitting to synthetic hospitalization data: ")
# this will print a summary of the inferred variables
# those distributions in the Config are now posteriors
inferer.infer(synthetic_observed_hospitalizations)
inferer.infer(jnp.array(synthetic_observed_hospitalizations))
print("saving a suite of inference visualizations ")
self.save_inference_timelines(
inferer, "local_inference_timeseries.csv"
Expand Down Expand Up @@ -152,4 +154,4 @@ def process_state(self, _: str, infer: bool = False):
os.mkdir("output")

runner = ExampleDynodeRunner("output/")
runner.process_state("USA", infer)
runner.process_state("USA", infer=infer)
30 changes: 28 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ docker = "^7.1.0"
bayeux-ml = "^0.1.14"


[tool.poetry.group.dev.dependencies]
pandas-stubs = "^2.2.3.241126"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
20 changes: 10 additions & 10 deletions src/dynode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@

# Defines all the different modules able to be imported from src
__all__ = [
AbstractParameters,
AbstractInitializer,
CovidSeroInitializer,
MechanisticInferer,
MechanisticRunner,
StaticValueParameters,
utils,
Config,
vis_utils,
AbstractDynodeRunner,
"AbstractParameters",
"AbstractInitializer",
"CovidSeroInitializer",
"MechanisticInferer",
"MechanisticRunner",
"StaticValueParameters",
"utils",
"Config",
"vis_utils",
"AbstractDynodeRunner",
]
8 changes: 5 additions & 3 deletions src/dynode/abstract_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from abc import ABC, abstractmethod
from typing import Any

from . import SEIC_Compartments, utils

Expand All @@ -18,10 +19,10 @@ class AbstractInitializer(ABC):
"""

@abstractmethod
def __init__(self, initializer_config):
def __init__(self, initializer_config) -> None:
# add these for mypy
self.INITIAL_STATE: SEIC_Compartments = tuple()
self.config = {}
self.INITIAL_STATE: SEIC_Compartments | None = None
self.config: Any = {}
pass

def get_initial_state(
Expand All @@ -30,6 +31,7 @@ def get_initial_state(
"""
Returns the initial state of the model as defined by the child class in __init__
"""
assert self.INITIAL_STATE is not None
return self.INITIAL_STATE

def load_initial_population_fractions(self) -> None:
Expand Down
13 changes: 9 additions & 4 deletions src/dynode/abstract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,14 @@ class AbstractParameters:
]

@abstractmethod
def __init__(self, parameters_config):
def __init__(self) -> None:
# add these for mypy type checker
self.config: Config = {}
self.INITIAL_STATE: SEIC_Compartments = tuple()
pass
self.config = Config("{}")
initial_state = tuple(
[jnp.arange(0), jnp.arange(0), jnp.arange(0), jnp.arange(0)]
)
assert len(initial_state) == 4
self.INITIAL_STATE: SEIC_Compartments = initial_state

def _solve_runner(
self, parameters: dict, tf: int, runner: MechanisticRunner
Expand Down Expand Up @@ -429,6 +432,7 @@ def seasonality(
k = 2 * jnp.pi / 365.0
# for a closed form solution to the combination of both cosine curves
# we must split along a boundary of second (summer) wave values
assert not isinstance(seasonality_second_wave, complex)
cos_val = jnp.where(
seasonality_second_wave > 0.2,
(seasonality_second_wave - 1)
Expand Down Expand Up @@ -721,4 +725,5 @@ def scale_initial_infections(
)
]
)
assert len(initial_state) == 4
return initial_state
4 changes: 3 additions & 1 deletion src/dynode/covid_sero_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def load_immune_history_via_serological_data(self) -> None:
# read in csv, do a bunch of data cleaning
sero_df = pd.read_csv(sero_path)
sero_df["type"] = sero_df["type"].fillna("None")
sero_df.columns = ["age", "hist", "vax", "0", "1", "2", "3", "4"]
sero_df.columns = pd.Index(
["age", "hist", "vax", "0", "1", "2", "3", "4"]
)
melted_df = pd.melt(
sero_df,
id_vars=["age", "hist", "vax"],
Expand Down
26 changes: 15 additions & 11 deletions src/dynode/dynode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional

import numpy as np
import pandas as pd # type: ignore
Expand Down Expand Up @@ -106,7 +106,7 @@ def _pad_fn(series, index_len, pad):

if len(series) < index_len:
return _pad_fn(series, index_len, pad)
return series
return np.array(series)

def _get_vaccination_timeseries(
self, vaccination_func, num_days_predicted
Expand Down Expand Up @@ -367,15 +367,15 @@ def _save_samples(self, samples, save_path):

def save_mcmc_chains_plot(
self,
samples: dict[str : list : np.ndarray],
samples: dict[str, list | np.ndarray],
save_filename: str = "mcmc_chains.png",
plot_kwargs: dict = {},
):
"""Saves a plot mapping the MCMC chains of the inference job
Parameters
----------
samples : dict[str: list | np.ndarray]
samples : dict[str, list | np.ndarray]
a dictionary (usually loaded from the checkpoint.json file) containing
the sampled posteriors for each chain in the shape
(num_chains, num_samples). All parameters generated with numpyro.plate
Expand All @@ -394,15 +394,15 @@ def save_mcmc_chains_plot(

def save_correlation_pairs_plot(
self,
samples: dict[str : list : np.ndarray],
samples: dict[str, list | np.ndarray],
save_filename: str = "mcmc_correlations.png",
plot_kwargs: dict = {},
):
"""_summary_
Parameters
----------
samples : dict[str: list | np.ndarray]
samples : dict[str, list | np.ndarray]
a dictionary (usually loaded from the checkpoint.json file) containing
the sampled posteriors for each chain in the shape
(num_chains, num_samples). All parameters generated with numpyro.plate
Expand Down Expand Up @@ -543,8 +543,8 @@ def save_inference_timelines(
inferer: MechanisticInferer,
timeline_filename: str = "azure_visualizer_timeline.csv",
particles_saved=1,
extra_timelines: pd.DataFrame = None,
tf: Union[int, None] = None,
extra_timelines: None | pd.DataFrame = None,
tf: int | None = None,
external_particle: dict[str, Array] = {},
verbose: bool = False,
) -> str:
Expand Down Expand Up @@ -626,12 +626,16 @@ def save_inference_timelines(
for (chain, particle), sol_dct in posteriors.items():
# content of `sol_dct` depends on return value of inferer.likelihood func
infection_timeline: Solution = sol_dct["solution"]
hospitalizations: Array = sol_dct["hospitalizations"]
static_parameters: dict[str, Array] = sol_dct["parameters"]
hospitalizations_tmp = sol_dct["hospitalizations"]
assert isinstance(hospitalizations_tmp, Array)
hospitalizations: Array = hospitalizations_tmp
parameters_tmp = sol_dct["parameters"]
assert isinstance(parameters_tmp, dict)
static_parameters: dict[str, Array] = parameters_tmp
# spoof the inferer to return our static parameters when calling `get_parameters()`
# instead of trying to sample like it normally does
spoof_static_inferer = copy.copy(inferer)
spoof_static_inferer.get_parameters = lambda: static_parameters
spoof_static_inferer.get_parameters = lambda: static_parameters # type: ignore # You shouldn't be setting a member function like this
df = self._generate_model_component_timelines(
spoof_static_inferer,
infection_timeline,
Expand Down
5 changes: 3 additions & 2 deletions src/dynode/mechanistic_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def likelihood(
dct = self.run_simulation(tf)
solution = dct["solution"]
predicted_metrics = dct["hospitalizations"]
assert isinstance(solution, Solution)
self._checkpoint_compartment_sizes(solution)
predicted_metrics = jnp.maximum(predicted_metrics, 1e-6)
numpyro.sample(
Expand Down Expand Up @@ -259,8 +260,8 @@ def _checkpoint_compartment_sizes(self, solution: Solution):
"final_timestep_%s" % compartment.name,
solution.ys[compartment][-1],
)
for date in getattr(self.config, "COMPARTMENT_SAVE_DATES", []):
date: datetime.date
for d in getattr(self.config, "COMPARTMENT_SAVE_DATES", []):
date: datetime.date = d
date_str = date.strftime("%Y_%m_%d")
sim_day = date_to_sim_day(date, self.config.INIT_DATE)
# ensure user requests a day we actually have in `solution`
Expand Down
7 changes: 6 additions & 1 deletion src/dynode/model_odes/seip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@ def compute_ds(strain, immune_state, ds, di_to_w0):
),
axis=-1,
).reshape(-1, 2)
assert len(combinations.T) == 2
# compute vectorized function on all possible immune_hist x exposing strain
ds_recovered = jnp.sum(
jax.vmap(compute_ds, in_axes=(0, 0, None, None))(
*combinations.T, ds, di_to_w0
# Pass the two columns explicitly (rather than star-unpacking) so mypy can check the call's arguments
combinations.T[0],
combinations.T[1],
ds,
di_to_w0,
),
axis=0,
)
Expand Down
Loading

0 comments on commit 3690865

Please sign in to comment.