diff --git a/.github/workflows/deploy-docs-book.yml b/.github/workflows/deploy-docs-book.yml index 8ae3a71..7754d64 100644 --- a/.github/workflows/deploy-docs-book.yml +++ b/.github/workflows/deploy-docs-book.yml @@ -24,7 +24,9 @@ jobs: cache: true - name: Install dependencies - run: pdm install --group docs + run: | + pdm install --prod + pdm install --group docs # Build the book - name: Build the book diff --git a/pdm.lock b/pdm.lock index 379f073..de8e5d6 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "pytorch", "visualisation", "docs"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:da1fa5fcd903126b6fd3e2981f8054b997a0b3ad306bbc8ef184ec8a354f15ff" +content_hash = "sha256:6f7af1faedd38411421502d75959265369f6d555760d866d17ed9adda9cee7e6" [[package]] name = "accessible-pygments" @@ -2427,6 +2427,17 @@ files = [ {file = "sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178"}, ] +[[package]] +name = "sphinxcontrib-mermaid" +version = "0.9.2" +requires_python = ">=3.7" +summary = "Mermaid diagrams in yours Sphinx powered docs" +groups = ["docs"] +files = [ + {file = "sphinxcontrib-mermaid-0.9.2.tar.gz", hash = "sha256:252ef13dd23164b28f16d8b0205cf184b9d8e2b714a302274d9f59eb708e77af"}, + {file = "sphinxcontrib_mermaid-0.9.2-py3-none-any.whl", hash = "sha256:6795a72037ca55e65663d2a2c1a043d636dc3d30d418e56dd6087d1459d98a5d"}, +] + [[package]] name = "sphinxcontrib-qthelp" version = "1.0.7" diff --git a/pyproject.toml b/pyproject.toml index c2d2d92..f14bb15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ visualisation = [ ] docs = [ "jupyter-book>=1.0.0", + "sphinxcontrib-mermaid>=0.9.2", ] [build-system] diff --git a/src/weather_model_graphs/networkx_utils.py b/src/weather_model_graphs/networkx_utils.py index 59792f8..93101a5 100644 --- a/src/weather_model_graphs/networkx_utils.py +++ b/src/weather_model_graphs/networkx_utils.py @@ -103,6 +103,17 @@ def split_graph_by_edge_attribute(graph, attr): f"No subgraphs were created. Check the edge attribute '{attr}'." ) + # copy node attributes + for subgraph in subgraphs.values(): + for node in subgraph.nodes: + subgraph.nodes[node].update(graph.nodes[node]) + + # check that at least one subgraph was created + if len(subgraphs) == 0: + raise ValueError( + f"No subgraphs were created. Check the edge attribute '{attr}'." + ) + return subgraphs