From df40add8c8408874a14f2c84c70a098f5ac5e293 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin <39861882+arik-shurygin@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:54:31 -0700 Subject: [PATCH] Vis utils (#277) * checkpoint, adding overview plot done in plt instead of plotly * checkpoint fixing the overview plots and adding the pairwise code * updating comments to match line_length limits, adding mcmc chain plot * checkpoint, integrating vis_utils into default behavior of abstract_azure_runner * increasing fig size of overview * changing size of the correlation_pairs plot * tight bounding boxes to avoid text cutoff * adding plotly back since it is still used by the azure visualizer for now --- poetry.lock | 322 +++++++++++- pyproject.toml | 2 + .../abstract_azure_runner.py | 79 ++- src/resp_ode/__init__.py | 3 +- src/resp_ode/vis_utils.py | 462 ++++++++++++++++++ 5 files changed, 863 insertions(+), 5 deletions(-) create mode 100644 src/resp_ode/vis_utils.py diff --git a/poetry.lock b/poetry.lock index d14feebd..98f4f1ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -105,6 +105,24 @@ typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} [package.extras] tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] +[[package]] +name = "asttokens" +version = "2.4.1" +description = "Annotate AST trees with source code positions" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, + {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"}, +] + +[package.dependencies] +six = ">=1.12.0" + +[package.extras] +astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] +test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] + [[package]] name = "azure-batch" version = "14.0.0" @@ -641,6 +659,23 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = 
"sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "comm" +version = "0.2.2" +description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." +optional = false +python-versions = ">=3.8" +files = [ + {file = "comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3"}, + {file = "comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e"}, +] + +[package.dependencies] +traitlets = ">=4" + +[package.extras] +test = ["pytest"] + [[package]] name = "cons" version = "0.4.6" @@ -1042,6 +1077,20 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "executing" +version = "2.1.0" +description = "Get the currently executing AST node of a frame, and other information" +optional = false +python-versions = ">=3.8" +files = [ + {file = "executing-2.1.0-py2.py3-none-any.whl", hash = "sha256:8d63781349375b5ebccc3142f4b30350c0cd9c79f921cde38be2be4637e98eaf"}, + {file = "executing-2.1.0.tar.gz", hash = "sha256:8ea27ddd260da8150fa5a708269c4a10e76161e2496ec3e587da9e3c0fe4b9ab"}, +] + +[package.extras] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] + [[package]] name = "fastprogress" version = "1.0.3" @@ -1421,6 +1470,63 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "ipython" +version = "8.18.0" +description = "IPython: Productive Interactive Computing" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ipython-8.18.0-py3-none-any.whl", hash = "sha256:d538a7a98ad9b7e018926447a5f35856113a85d08fd68a165d7871ab5175f6e0"}, + {file = "ipython-8.18.0.tar.gz", hash = "sha256:4feb61210160f75e229ce932dbf8b719bff37af123c0b985fd038b14233daa16"}, +] + +[package.dependencies] +colorama = {version = "*", markers = 
"sys_platform == \"win32\""} +decorator = "*" +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +jedi = ">=0.16" +matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} +prompt-toolkit = ">=3.0.30,<3.0.37 || >3.0.37,<3.1.0" +pygments = ">=2.4.0" +stack-data = "*" +traitlets = ">=5" + +[package.extras] +all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio (<0.22)", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] +black = ["black"] +doc = ["docrepr", "exceptiongroup", "ipykernel", "matplotlib", "pickleshare", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio (<0.22)", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"] +kernel = ["ipykernel"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath"] +test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath", "trio"] + +[[package]] +name = "ipywidgets" +version = "8.1.5" +description = "Jupyter interactive widgets" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245"}, + {file = "ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17"}, +] + +[package.dependencies] +comm = ">=0.1.3" +ipython = ">=6.1.0" +jupyterlab-widgets = ">=3.0.12,<3.1.0" +traitlets = ">=4.3.1" 
+widgetsnbextension = ">=4.0.12,<4.1.0" + +[package.extras] +test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] + [[package]] name = "isodate" version = "0.6.1" @@ -1542,6 +1648,56 @@ files = [ [package.dependencies] typeguard = "2.13.3" +[[package]] +name = "jedi" +version = "0.19.1" +description = "An autocompletion tool for Python that can be used for text editors." +optional = false +python-versions = ">=3.6" +files = [ + {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, + {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, +] + +[package.dependencies] +parso = ">=0.8.3,<0.9.0" + +[package.extras] +docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] + +[[package]] +name = "jupyter-core" +version = "5.7.2" +description = "Jupyter core package. A base package on which Jupyter projects rely." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409"}, + {file = "jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9"}, +] + +[package.dependencies] +platformdirs = ">=2.5" +pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\""} +traitlets = ">=5.3" + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] +test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"] + +[[package]] +name = "jupyterlab-widgets" +version = "3.0.13" +description = "Jupyter interactive widgets for JupyterLab" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54"}, + {file = "jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed"}, +] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -1782,6 +1938,20 @@ python-dateutil = ">=2.7" [package.extras] dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setuptools (>=64)", "setuptools_scm (>=7)"] +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +description = "Inline Matplotlib backend for Jupyter" +optional = false +python-versions = ">=3.8" +files = [ + {file = "matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca"}, + {file = "matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90"}, +] + +[package.dependencies] +traitlets = "*" + [[package]] name = "mdit-py-plugins" version = "0.4.1" @@ -2423,6 +2593,21 @@ 
sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "parso" +version = "0.8.4" +description = "A Python Parser" +optional = false +python-versions = ">=3.6" +files = [ + {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, + {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"}, +] + +[package.extras] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["docopt", "pytest"] + [[package]] name = "pathlib" version = "1.0.1" @@ -2434,6 +2619,20 @@ files = [ {file = "pathlib-1.0.1.tar.gz", hash = "sha256:6940718dfc3eff4258203ad5021090933e5c04707d5ca8cc9e73c94a7894ea9f"}, ] +[[package]] +name = "pexpect" +version = "4.9.0" +description = "Pexpect allows easy control of interactive console applications." 
+optional = false +python-versions = "*" +files = [ + {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, + {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, +] + +[package.dependencies] +ptyprocess = ">=0.5" + [[package]] name = "pillow" version = "10.4.0" @@ -2558,6 +2757,21 @@ docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx- test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] type = ["mypy (>=1.8)"] +[[package]] +name = "plotly" +version = "5.24.1" +description = "An open-source, interactive data visualization library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "plotly-5.24.1-py3-none-any.whl", hash = "sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089"}, + {file = "plotly-5.24.1.tar.gz", hash = "sha256:dbc8ac8339d248a4bcc36e08a5659bacfe1b079390b8953533f4eb22169b4bae"}, +] + +[package.dependencies] +packaging = "*" +tenacity = ">=6.2.0" + [[package]] name = "pluggy" version = "1.5.0" @@ -2644,6 +2858,31 @@ files = [ {file = "protobuf-5.28.2.tar.gz", hash = "sha256:59379674ff119717404f7454647913787034f03fe7049cbef1d74a97bb4593f0"}, ] +[[package]] +name = "ptyprocess" +version = "0.7.0" +description = "Run a subprocess in a pseudo terminal" +optional = false +python-versions = "*" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +description = "Safely evaluate AST nodes without side effects" +optional = false +python-versions = "*" +files = [ + {file = "pure_eval-0.2.3-py3-none-any.whl", hash = 
"sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0"}, + {file = "pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "pyarrow" version = "17.0.0" @@ -3186,6 +3425,27 @@ doc = ["griffe", "jupyter", "jupyter-client (<8.0.0)", "pydantic (>=2.7.4)", "qu test = ["astropy", "bokeh", "coverage", "duckdb", "faicons", "folium", "geodatasets", "geopandas", "great-tables", "holoviews", "ipyleaflet", "missingno", "palmerpenguins", "playwright (>=1.43.0)", "plotly", "plotnine", "psutil", "pytest (>=6.2.4)", "pytest-asyncio (>=0.17.2)", "pytest-cov", "pytest-playwright (>=0.3.0)", "pytest-rerunfailures", "pytest-timeout", "pytest-xdist", "ridgeplot", "rsconnect-python", "scikit-learn", "seaborn", "shinywidgets", "suntime", "syrupy", "timezonefinder", "xarray"] theme = ["libsass (>=0.23.0)"] +[[package]] +name = "shinywidgets" +version = "0.3.3" +description = "Render ipywidgets in Shiny applications" +optional = false +python-versions = ">=3.8" +files = [ + {file = "shinywidgets-0.3.3-py3-none-any.whl", hash = "sha256:e0290a6985e41eafdfbd487d6a86449d6e8692ac0d99dee307f5ad0aa1fd5a93"}, + {file = "shinywidgets-0.3.3.tar.gz", hash = "sha256:9827da55d3f57dfa0b8c5158547dde048d52a835bca1a8720655fcb0990ac008"}, +] + +[package.dependencies] +ipywidgets = ">=7.6.5" +jupyter-core = "*" +python-dateutil = ">=2.8.2" +shiny = ">=0.6.1.9005" + +[package.extras] +dev = ["black (>=23.1.0)", "flake8 (==3.9.2)", "flake8 (>=6.0.0)", "isort (>=5.11.2)", "pyright (>=1.1.284)", "wheel"] +test = ["pytest (>=6.2.4)"] + [[package]] name = "six" version = "1.16.0" @@ -3219,6 +3479,25 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "stack-data" +version = "0.6.3" +description = "Extract data from python stack frames and tracebacks for informative displays" 
+optional = false +python-versions = "*" +files = [ + {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, + {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, +] + +[package.dependencies] +asttokens = ">=2.1.0" +executing = ">=1.2.0" +pure-eval = "*" + +[package.extras] +tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] + [[package]] name = "starlette" version = "0.38.2" @@ -3236,6 +3515,21 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tensorflow-probability" version = "0.24.0" @@ -3349,6 +3643,21 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "traitlets" +version = "5.14.3" +description = "Traitlets Python configuration system" +optional = false +python-versions = ">=3.8" +files = [ + {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, + {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, +] + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] + [[package]] name = 
"typeguard" version = "2.13.3" @@ -3647,6 +3956,17 @@ files = [ {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"}, ] +[[package]] +name = "widgetsnbextension" +version = "4.0.13" +description = "Jupyter interactive widgets for Jupyter Notebook" +optional = false +python-versions = ">=3.7" +files = [ + {file = "widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71"}, + {file = "widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6"}, +] + [[package]] name = "xarray" version = "2024.9.0" @@ -3715,4 +4035,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "ed91b1492ebe61f4e5ba0c87c97bcd6404e4f8bbd9255383055a97991f2f4ae5" +content-hash = "1d4666877faa10fa59ac09e5d3191d66e9cb72ebed01f1b691bb4186e953cc77" diff --git a/pyproject.toml b/pyproject.toml index f51fe4f9..7fa66fb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ mypy = "^1.10.0" requests = "^2.32.3" docker = "^7.1.0" bayeux-ml = "^0.1.14" +plotly = "^5.24.1" +shinywidgets = "^0.3.3" [build-system] diff --git a/src/mechanistic_azure/abstract_azure_runner.py b/src/mechanistic_azure/abstract_azure_runner.py index 08f36f29..698c4dc0 100644 --- a/src/mechanistic_azure/abstract_azure_runner.py +++ b/src/mechanistic_azure/abstract_azure_runner.py @@ -24,6 +24,7 @@ SEIC_Compartments, StaticValueParameters, utils, + vis_utils, ) @@ -368,13 +369,74 @@ def _save_samples(self, samples, save_path): samples[param] = samples[param].tolist() json.dump(samples, open(save_path, "w")) + def save_mcmc_chains_plot( + self, + samples: dict[str : list : np.ndarray], + save_filename: str = "mcmc_chains.png", + plot_kwargs: dict = {}, + ): + """Saves a plot mapping the MCMC chains of the inference job + + Parameters + ---------- + samples : dict[str: list | np.ndarray] + a 
dictionary (usually loaded from the checkpoint.json file) containing + the sampled posteriors for each chain in the shape + (num_chains, num_samples). All parameters generated with numpyro.plate + and thus have a third dimension (num_chains, num_samples, num_plates) + are flattened to the desired shape and displayed as + separate parameters with _i suffix for each i in num_plates. + save_filename : str, optional + filename saved under, by default "mcmc_chains.png" + plot_kwargs : dict, optional + additional keyword arguments to pass to + vis_utils.plot_mcmc_chains() + """ + fig = vis_utils.plot_mcmc_chains(samples, **plot_kwargs) + save_path = os.path.join(self.azure_output_dir, save_filename) + fig.savefig(save_path, bbox_inches="tight") + + def save_correlation_pairs_plot( + self, + samples: dict[str : list : np.ndarray], + save_filename: str = "mcmc_correlations.png", + plot_kwargs: dict = {}, + ): + """Saves a plot of the pairwise correlations between the sampled + parameters of the inference job + + Parameters + ---------- + samples : dict[str: list | np.ndarray] + a dictionary (usually loaded from the checkpoint.json file) containing + the sampled posteriors for each chain in the shape + (num_chains, num_samples). All parameters generated with numpyro.plate + and thus have a third dimension (num_chains, num_samples, num_plates) + are flattened to the desired shape and displayed as + separate parameters with _i suffix for each i in num_plates. 
+ save_filename : str, optional + filename saved under, by default "mcmc_correlations.png" + plot_kwargs : dict, optional + additional keyword arguments to pass to + vis_utils.plot_checkpoint_inference_correlation_pairs(), + by default {} + """ + fig = vis_utils.plot_checkpoint_inference_correlation_pairs( + samples, **plot_kwargs + ) + save_path = os.path.join(self.azure_output_dir, save_filename) + fig.savefig(save_path, bbox_inches="tight") + def save_inference_posteriors( self, inferer: MechanisticInferer, save_filename="checkpoint.json", exclude_prefixes=["final_timestep"], + save_chains_plot=True, + save_pairs_correlation_plot=True, ) -> None: - """saves output of mcmc.get_samples(), does nothing if `inferer` has not compelted inference yet. + """saves output of mcmc.get_samples(), does nothing if `inferer` + has not completed inference yet. By default saves accompanying + visualizations for interpretability. Parameters ---------- @@ -383,8 +445,14 @@ def save_inference_posteriors( save_filename : str, optional output filename, by default "checkpoint.json" exclude_prefixes: list[str], optional - a list of strs that, if found in a sample name, are exlcuded from the saved json. - This is common for large logging info that will bloat filesize like, by default ["final_timestep"] + a list of strs that, if found in a sample name, + are excluded from the saved json. 
This is common for large logging + info that will bloat filesize like, by default ["final_timestep"] + save_chains_plot: bool, optional + whether to save accompanying mcmc chains plot, by default True + save_pairs_correlation_plot: bool, optional + whether to save accompanying pairs correlation plot, + by default True Returns ------------ None @@ -402,6 +470,11 @@ def save_inference_posteriors( } save_path = os.path.join(self.azure_output_dir, save_filename) self._save_samples(samples, save_path) + # by default save an accompanying mcmc chains plot for readability + if save_chains_plot: + self.save_mcmc_chains_plot(samples) + if save_pairs_correlation_plot: + self.save_correlation_pairs_plot(samples) else: warnings.warn( "attempting to call `save_inference_posteriors` before inference is complete. Something is likely wrong..." diff --git a/src/resp_ode/__init__.py b/src/resp_ode/__init__.py index a1813c60..f2a4f479 100644 --- a/src/resp_ode/__init__.py +++ b/src/resp_ode/__init__.py @@ -26,7 +26,7 @@ jax.Array, ] -from . import utils +from . 
import utils, vis_utils # keep imports relative to avoid circular importing from .abstract_initializer import AbstractInitializer @@ -49,4 +49,5 @@ StaticValueParameters, utils, Config, + vis_utils, ] diff --git a/src/resp_ode/vis_utils.py b/src/resp_ode/vis_utils.py new file mode 100644 index 00000000..ea6424db --- /dev/null +++ b/src/resp_ode/vis_utils.py @@ -0,0 +1,462 @@ +"""A series of utility functions for generating visualizations for the model""" + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.axes import Axes +from matplotlib.colors import LinearSegmentedColormap + +from .utils import drop_keys_with_substring, flatten_list_parameters + + +def _cleanup_and_normalize_timelines( + all_state_timelines: pd.DataFrame, + plot_types: np.ndarray[str], + plot_normalizations: np.ndarray[int], + state_pop_sizes: dict[str, int], +): + # Select columns with 'float64' dtype + float_cols = list(all_state_timelines.select_dtypes(include="float64")) + # round down near-zero values to zero to make plots cleaner + all_state_timelines[float_cols] = all_state_timelines[float_cols].mask( + np.isclose(all_state_timelines[float_cols], 0, atol=1e-4), 0 + ) + for plot_type, plot_normalization in zip(plot_types, plot_normalizations): + for state_name, state_pop in state_pop_sizes.items(): + # if normalization is set to 1, we dont normalize at all. 
+ normalization_factor = ( + plot_normalization / state_pop + if plot_normalization > 1 + else 1.0 + ) + # select all columns from that column type + cols = [ + col for col in all_state_timelines.columns if plot_type in col + ] + # update that states columns by the normalization factor + all_state_timelines.loc[ + all_state_timelines["state"] == state_name, + cols, + ] *= normalization_factor + return all_state_timelines + + +def plot_model_overview_subplot_matplotlib( + timeseries_df: pd.DataFrame, + pop_sizes: dict[str, int], + plot_types: np.ndarray[str] = np.array( + [ + "seasonality_coef", + "vaccination_", + "_external_introductions", + "_strain_proportion", + "_average_immunity", + "total_infection_incidence", # TODO MAKE AGE SPECIFIC + "pred_hosp_", + ] + ), + plot_titles: np.ndarray[str] = np.array( + [ + "Seasonality Coefficient", + "Vaccination Rate By Age", + "External Introductions by Strain (per 100k)", + "Strain Proportion of New Infections", + "Average Population Immunity Against Strains", + "Total Infection Incidence (per 100k)", + "Predicted Hospitalizations (per 100k)", + ] + ), + plot_normalizations: np.ndarray[int] = np.array( + [1, 1, 100000, 1, 1, 100000, 100000] + ), + matplotlib_style: list[str] + | str = [ + "seaborn-v0_8-colorblind", + ], +) -> plt.Figure: + """Given a dataframe resembling the azure_visualizer_timeline csv, + if it exists, returns an overview figure. The figure will contain 1 column + per state in `timeseries_df["state"]` if the column exists. The + figure will contain one row per plot_type + + Parameters + ---------- + timeseries_df : pandas.DataFrame + a dataframe containing at least the following columns: + ["date", "chain_particle", "state"] followed by columns identifying + different timeseries of interest to be plotted. + E.g. vaccination_0_17, vaccination_18_49, total_infection_incidence. + columns that share the same plot_type will be plotted on the same plot, + with their differences in the legend. 
+ All chain_particle replicates are plotted as low + opacity lines for each plot_type + pop_sizes : dict[str, int] + population sizes of each state as a dictionary. + Keys must match the "state" column within timeseries_df + plot_types : np.ndarray[str], optional + each of the plot types to be plotted. + plot_types not found in `timeseries_df` are skipped. + columns are identified using the "in" operation, + so plot_type must be found in each of its identified columns + by default ["seasonality_coef", "vaccination_", + "_external_introductions", "_strain_proportion", "_average_immunity", + "total_infection_incidence", "pred_hosp_"] + plot_titles : np.ndarray[str], optional + titles for each plot_type as displayed on each subplot, + by default [ "Seasonality Coefficient", "Vaccination Rate By Age", + "External Introductions by Strain (per 100k)", + "Strain Proportion of New Infections", + "Average Population Immunity Against Strains", + "Total Infection Incidence (per 100k)", + "Predicted Hospitalizations (per 100k)"] + plot_normalizations : np.ndarray[int] + normalization factor for each plot type + matplotlib_style: list[str] | str + matplotlib style to plot in, by default ["seaborn-v0_8-colorblind"] + + Returns + ------- + matplotlib.pyplot.Figure + matplotlib Figure containing subplots with a column for each state + and a row for each plot_type + """ + necessary_cols = ["date", "chain_particle", "state"] + assert all( + [ + necessary_col in timeseries_df.columns + for necessary_col in necessary_cols + ] + ), ( + "missing a necessary column within timeseries_df, require %s but got %s" + % (str(necessary_cols), str(timeseries_df.columns)) + ) + num_states = len(timeseries_df["state"].unique()) + # we are counting the number of plot_types that are within timelines.columns + # this way we dont try to plot something that timelines does not have + plots_in_timelines = [ + any([plot_type in col for col in timeseries_df.columns]) + for plot_type in plot_types + ] + 
num_unique_plots_in_timelines = sum(plots_in_timelines) + # select only the plots we actually find within `timelines` + plot_types = plot_types[plots_in_timelines] + plot_titles = plot_titles[plots_in_timelines] + plot_normalizations = plot_normalizations[plots_in_timelines] + # normalize our dataframe by the given y axis normalization schemes + timeseries_df = _cleanup_and_normalize_timelines( + timeseries_df, + plot_types, + plot_normalizations, + pop_sizes, + ) + with plt.style.context(matplotlib_style): + fig, ax = plt.subplots( + nrows=num_unique_plots_in_timelines, + ncols=num_states, + sharex=True, + sharey="row", + squeeze=False, + figsize=(6 * num_states, 3 * num_unique_plots_in_timelines), + ) + # melt this df down to have an ID column "column" and a value column "val" + id_vars = ["date", "state", "chain_particle"] + rest = [x for x in timeseries_df.columns if x not in id_vars] + timelines_melt = pd.melt( + timeseries_df, + id_vars=["date", "state", "chain_particle"], + value_vars=rest, + var_name="column", + value_name="val", + ) + # convert to datetime if not already + timelines_melt["date"] = pd.to_datetime(timelines_melt["date"]) + + # go through each plot type, look for matching columns and plot + # that plot_type for each chain_particle pair. 
+ for state_num, state in enumerate(timeseries_df["state"].unique()): + state_df = timelines_melt[timelines_melt["state"] == state] + print("Plotting State : " + state) + for plot_num, (plot_title, plot_type) in enumerate( + zip(plot_titles, plot_types) + ): + plot_ax = ax[plot_num][state_num] + # for example "vaccination_" in "vaccination_0_17" is true + # so we include this column in the plot under that plot_type + plot_df = state_df[[plot_type in x for x in state_df["column"]]] + columns_to_plot = plot_df["column"].unique() + # if we are plotting multiple lines, lets modify the legend to + # only display the differences between each line + if len(columns_to_plot) > 1: + plot_df.loc[:, "column"] = plot_df.loc[:, "column"].apply( + lambda x: x.replace(plot_type, "") + ) + unique_columns = plot_df["column"].unique() + # plot all chain_particles as thin transparent lines + # turn off legends since there will many lines + sns.lineplot( + plot_df, + x="date", + y="val", + hue="column", + units="chain_particle", + ax=plot_ax, + estimator=None, + alpha=0.3, + lw=0.25, + legend=False, + hue_order=unique_columns, + ) + # plot a median line of all particles with high opacity + # use this as our legend line + medians = ( + plot_df.groupby(by=["date", "column"])["val"] + .median() + .reset_index() + ) + sns.lineplot( + medians, + x="date", + y="val", + hue="column", + ax=plot_ax, + estimator=None, + alpha=1.0, + lw=2, + legend="auto", + hue_order=unique_columns, + ) + # remove y labels + plot_ax.set_ylabel("") + plot_ax.set_title(plot_title) + # make all legends except those on far right invisible + plot_ax.get_legend().set_visible(False) + # create legend for the right most plot only + if state_num == num_states - 1: + with plt.style.context(matplotlib_style): + for lh in plot_ax.get_legend().legend_handles: + lh.set_alpha(1) + plot_ax.legend( + bbox_to_anchor=(1.0, 0.5), + loc="center left", + ) + # add column titles on the top of each col for the states + for ax, state in 
def plot_checkpoint_inference_correlation_pairs(
    posteriors: dict[str, np.ndarray | list],
    max_samples_calculated: int = 100,
    matplotlib_style: list[str]
    | str = [
        "seaborn-v0_8-colorblind",
    ],
):
    """Plot pairwise correlations of all sampled parameters.

    Given a dictionary mapping a sampled parameter's name to its posterior
    samples, returns a figure plotting the correlation of each sampled
    parameter with all other sampled parameters: the upper half shows the
    correlation values, the diagonal a histogram of the posterior values,
    and the bottom half a scatter plot of the parameters against each other
    along with a matching trend line.

    Parameters
    ----------
    posteriors : dict[str, np.ndarray | list]
        a dictionary (usually loaded from the checkpoint.json file) containing
        the sampled posteriors for each chain in the shape
        (num_chains, num_samples). All parameters generated with numpyro.plate
        and thus having a third dimension (num_chains, num_samples, num_plates)
        are flattened to the desired shape and displayed as separate
        parameters with an _i suffix for each i in num_plates.
    max_samples_calculated : int
        a max cap of posterior samples per chain on which calculations such
        as correlations and plotting will be performed; set for efficiency
        of plot generation, set to -1 to disable cap, by default 100
    matplotlib_style : list[str] | str
        matplotlib style to plot in, by default ["seaborn-v0_8-colorblind"]

    Returns
    -------
    matplotlib.pyplot.Figure
        Figure with `n` rows and `n` columns where `n` is the number of
        sampled parameters
    """
    # convert lists to np.arrays so slicing/flattening below is uniform
    posteriors = {
        key: np.array(val) if isinstance(val, list) else val
        for key, val in posteriors.items()
    }
    posteriors: dict[str, np.ndarray] = flatten_list_parameters(posteriors)
    # drop any final_timestep parameters in case they snuck in
    posteriors = drop_keys_with_substring(posteriors, "final_timestep")
    number_of_samples = next(iter(posteriors.values())).shape[1]
    # if we are dealing with many samples per chain, narrow down to
    # max_samples_calculated randomly chosen samples per chain
    if (
        number_of_samples > max_samples_calculated
        and max_samples_calculated != -1
    ):
        selected_indices = np.random.choice(
            number_of_samples, size=max_samples_calculated, replace=False
        )
        posteriors = {
            key: matrix[:, selected_indices]
            for key, matrix in posteriors.items()
        }
        number_of_samples = next(iter(posteriors.values())).shape[1]
    # flatten matrices (merging chains) and create correlation DataFrame
    posteriors = {
        key: np.array(matrix).flatten() for key, matrix in posteriors.items()
    }
    columns = list(posteriors.keys())
    num_cols = len(columns)
    # shrink the labels as the grid grows, clamped to [2, 10] points
    label_size = max(2, min(10, 200 / num_cols))
    # compute the correlation matrix; diagonal starts at top left
    samples_df = pd.DataFrame(posteriors)
    correlation_df = samples_df.corr(method="pearson")

    # negative correlations map to red, none to grey, positive to blue
    cmap = LinearSegmentedColormap.from_list("", ["red", "grey", "blue"])

    def _normalize_coefficients_to_0_1(r):
        # squashes r in [-1, 1] into [0, 1] via (r - min()) / (max() - min())
        return (r + 1) / 2

    def reg_coef(x, y, label=None, color=None, **kwargs):
        # upper-triangle cell: print the correlation coefficient, sized and
        # colored by its magnitude (label/color args required by PairGrid.map)
        ax = plt.gca()
        r = correlation_df.loc[x.name, y.name]
        ax.annotate(
            "{:.2f}".format(r),
            xy=(0.5, 0.5),
            xycoords="axes fraction",
            ha="center",
            # vary size and color by the magnitude of correlation
            color=cmap(_normalize_coefficients_to_0_1(r)),
            size=label_size * abs(r) + label_size,
        )
        ax.set_axis_off()

    def reg_plot_custom(x, y, label=None, color=None, **kwargs):
        # lower-triangle cell: scatter plot plus a regression trend line
        # colored to match the printed correlation coefficient
        ax = plt.gca()
        r = correlation_df.loc[x.name, y.name]
        sns.regplot(
            x=x,
            y=y,
            ax=ax,
            fit_reg=True,
            scatter_kws={"alpha": 0.2, "s": 0.5},
            line_kws={
                "color": cmap(_normalize_coefficients_to_0_1(r)),
                "linewidth": 1,
            },
        )

    # create the plot
    with plt.style.context(matplotlib_style):
        g = sns.PairGrid(
            data=samples_df,
            vars=columns,
            diag_sharey=False,
            layout_pad=0.01,
        )
        g.map_upper(reg_coef)
        g = g.map_lower(reg_plot_custom)
        g = g.map_diag(sns.histplot, kde=True)
        for ax in g.axes.flatten():
            plt.setp(ax.get_xticklabels(), rotation=45, size=label_size)
            plt.setp(ax.get_yticklabels(), rotation=45, size=label_size)
            # re-set the existing axis labels with rotation and sizing
            xlabel = ax.get_xlabel()
            ax.set_xlabel(xlabel, size=label_size, rotation=90, labelpad=4.0)
            ylabel = ax.get_ylabel()
            ax.set_ylabel(ylabel, size=label_size, rotation=0, labelpad=15.0)
            # only outer plots keep tick labels, reducing clutter
            ax.label_outer(remove_inner_ticks=True)
        # fix the figure at 2000x2000 pixels regardless of dpi
        px = 1 / plt.rcParams["figure.dpi"]
        g.figure.set_size_inches((2000 * px, 2000 * px))
    return g.figure


def plot_mcmc_chains(
    samples: dict[str, np.ndarray | list],
    matplotlib_style: list[str]
    | str = [
        "seaborn-v0_8-colorblind",
    ],
) -> plt.Figure:
    """Plot each MCMC chain of each sampled parameter.

    Given a `samples` dictionary containing posterior samples, often returned
    from numpyro.get_samples(group_by_chain=True) or from the checkpoint.json
    saved file, plots each MCMC chain for each sampled parameter in a roughly
    square grid of subplots.

    Parameters
    ----------
    samples : dict[str, np.ndarray | list]
        a dictionary (usually loaded from the checkpoint.json file) containing
        the sampled posteriors for each chain in the shape
        (num_chains, num_samples). All parameters generated with numpyro.plate
        and thus having a third dimension (num_chains, num_samples, num_plates)
        are flattened and displayed as separate parameters with an _i suffix
        for each i in num_plates.
    matplotlib_style : list[str] | str, optional
        matplotlib style to plot in, by default ["seaborn-v0_8-colorblind"]

    Returns
    -------
    matplotlib.pyplot.Figure
        matplotlib figure containing the plots
    """
    # convert lists to np.arrays so chain indexing below is uniform
    samples = {
        key: np.array(val) if isinstance(val, list) else val
        for key, val in samples.items()
    }
    samples: dict[str, np.ndarray] = flatten_list_parameters(samples)
    # drop any final_timestep parameters in case they snuck in
    samples = drop_keys_with_substring(samples, "final_timestep")
    param_names = list(samples.keys())
    num_params = len(param_names)
    if num_params == 0:
        # nothing to plot; return an empty figure instead of raising
        # IndexError on param_names[0] below
        return plt.figure()
    num_chains = samples[param_names[0]].shape[0]
    # calculate the number of rows and columns for a square-ish layout
    num_cols = int(np.ceil(np.sqrt(num_params)))
    num_rows = int(np.ceil(num_params / num_cols))
    # create a figure with subplots
    with plt.style.context(matplotlib_style):
        fig, axs = plt.subplots(
            num_rows,
            num_cols,
            figsize=(3 * num_cols, 3 * num_rows),
            squeeze=False,
        )
        # flatten the axis array for easy indexing
        axs_flat = axs.flatten()
        # loop over each parameter and plot its chains
        for i, param_name in enumerate(param_names):
            ax: Axes = axs_flat[i]
            for chain in range(num_chains):
                ax.plot(samples[param_name][chain], label=f"chain {chain}")
            ax.set_title(param_name)
            # hide x-axis labels except for bottom plots to reduce clutter
            if i < (num_params - num_cols):
                ax.set_xticklabels([])

        # turn off any unused subplots
        for j in range(i + 1, len(axs_flat)):
            axs_flat[j].axis("off")
        plt.tight_layout()
        # a single shared legend; every subplot has the same chain labels
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc="outside upper center")
    return fig