Commit

Jupyterbook docs (#7)
* adapt pyg writing for neural-lam

* fix missing and circular imports

* make regular input grid assumption explicit

* handle missing edge attributes in splitting

* minor fix for neural-lam save

* working book compilation

* fully working first jupyterbook!

* tweak intro

* add docs build+deploy github action

* cache pdm install

* cleanup
leifdenby authored May 10, 2024
1 parent 940609b commit de68885
Showing 21 changed files with 1,858 additions and 950 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/deploy-docs-book.yml
@@ -0,0 +1,42 @@
name: build-docs

# Run this when the master or main branch changes
on:
  push:
    branches:
      - master
      - main
      - jupyterbook-docs

# This job installs dependencies, builds the book, and pushes it to `gh-pages`
jobs:
  deploy-book:
    runs-on: ubuntu-latest
    permissions:
      pages: write
      id-token: write
    steps:
      - uses: actions/checkout@v4

      - name: Setup PDM
        uses: pdm-project/setup-pdm@v4
        with:
          cache: true

      - name: Install dependencies
        run: pdm install --group docs

      # Build the book
      - name: Build the book
        run: |
          LOGURU_LEVEL=WARNING jupyter-book build docs/
      # Upload the book's HTML as an artifact
      - name: Upload artifact
        uses: actions/upload-pages-artifact@v2
        with:
          path: "docs/_build/html"

      # Deploy the book's HTML to GitHub Pages
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v2
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -15,6 +15,7 @@ repos:
    rev: 22.3.0
    hooks:
      - id: black
      - id: black-jupyter
  - repo: https://github.com/PyCQA/flake8
    rev: 6.1.0
    hooks:
36 changes: 36 additions & 0 deletions docs/_config.yml
@@ -0,0 +1,36 @@
# Book settings
# Learn more at https://jupyterbook.org/customize/config.html

title: Graphs for data-driven weather models
author: The ML LAM Community
#logo: logo.png

# Force re-execution of notebooks on each build.
# See https://jupyterbook.org/content/execute.html
execute:
  execute_notebooks: force

# Define the name of the latex output file for PDF builds
latex:
  latex_documents:
    targetname: book.tex

# Add a bibtex file so that we can create citations
bibtex_bibfiles:
  - references.bib

# Information about where the book exists on the web
repository:
  url: https://github.com/mllam/weather-model-graphs
  path_to_book: docs # Optional path to your book, relative to the repository root
  branch: master # Which branch of the repository should be used when creating links (optional)

# Add GitHub buttons to your book
# See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository
html:
  use_issues_button: true
  use_repository_button: true

sphinx:
  extra_extensions:
    - sphinxcontrib.mermaid
9 changes: 9 additions & 0 deletions docs/_toc.yml
@@ -0,0 +1,9 @@
# Table of contents
# Learn more at https://jupyterbook.org/customize/toc.html

format: jb-book
root: intro
chapters:
- file: background
- file: design
- file: creating_the_graph
126 changes: 126 additions & 0 deletions docs/background.ipynb
@@ -0,0 +1,126 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Background\n",
"\n",
"> *This section introduces graph-based data-driven weather models and how they work. Inline code using `weather-model-graphs` demonstrates the different parts of the graph.*\n",
"\n",
"Current graph-based weather models use the [encode-process-decode paradigm](https://arxiv.org/abs/1806.01261) on [message-passing graphs](https://arxiv.org/abs/1704.01212) to predict the atmospheric state auto-regressively in time and so produce a weather forecast.\n",
"The graphs are directed acyclic graphs (DAGs) with the nodes representing features (physical variables) at a given location in space and the edges representing flow of information.\n",
"The encode-process-decode paradigm is a three-step process that involves encoding the input data into a latent space, processing the latent space to make predictions, and decoding the predictions to produce the output data. \n",
"\n",
"## The graph nodes\n",
"\n",
"Using the nomenclature of [Lam et al 2022](https://arxiv.org/abs/2212.12794) the nodes in `weather-model-graphs` are split into two types:\n",
"\n",
"- **grid nodes**: representing the physical variables of the atmospheric state at a specific `(x,y)` coordinate, both in the initial state given to the model (input) and in the model's prediction (output)\n",
"\n",
"- **mesh nodes**: representing the intermediate (latent) state of the model at a specific `(x,y)` coordinate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import weather_model_graphs as wmg\n",
"\n",
"# create some fake cartesian coordinates\n",
"def _create_fake_xy(N=10):\n",
"    x = np.linspace(0.0, 1.0, N)\n",
"    y = np.linspace(0.0, 1.0, N)\n",
"    xy = np.stack(np.meshgrid(x, y), axis=0)\n",
"    return xy\n",
"\n",
"\n",
"xy_grid = _create_fake_xy(N=10)\n",
"\n",
"graph = wmg.create.archetype.create_keisler_graph(xy_grid=xy_grid)\n",
"\n",
"# remove all edges from the graph\n",
"graph.remove_edges_from(list(graph.edges))\n",
"\n",
"ax = wmg.visualise.nx_draw_with_pos_and_attr(graph, node_color_attr=\"type\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the code snippet above shows, the only input that graph generation in `weather-model-graphs` requires is the static `(x,y)` *grid* coordinates at which the atmospheric state is defined; these coordinates stay fixed while the state changes over time. They are used to create the **grid nodes** of the graph, with one node for each `(x,y)` coordinate.\n",
"\n",
"In addition to grid nodes, the graph also contains **mesh nodes** that represent the latent space of the model at a set of `(x,y)` coordinates (in general a different set of coordinates from those of the **grid nodes**)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The graph edges\n",
"\n",
"With these two sets of nodes, the graph is constructed by connecting the **grid nodes** to the **mesh nodes** and the **mesh nodes** to each other.\n",
"The edges between the **grid nodes** and the **mesh nodes** represent the encoding of the physical variables into the latent space of the model, while the edges between the **mesh nodes** represent the processing of the latent space through the time evolution of the atmospheric state.\n",
"\n",
"In summary, the complete message-passing graph consists of three components:\n",
"\n",
"- **grid-to-mesh** (`g2m`): the encoding component, where edges represent the encoding of physical variables into the latent space of the model\n",
"\n",
"- **mesh-to-mesh** (`m2m`): the processing component, where edges represent information flow between nodes, updating the latent representation at mesh nodes through the time evolution of the atmospheric state\n",
"\n",
"- **mesh-to-grid** (`m2g`): the decoding component, where edges represent the decoding of the latent space back into physical variables\n",
"\n",
"Practically, the **mesh-to-grid** and **grid-to-mesh** updates can probably also encode some of the time evolution processing, in addition to the latent space encoding/decoding, unless the GNN is trained specifically as an auto-encoder using the same graph as input and output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"graph = wmg.create.archetype.create_keisler_graph(xy_grid=xy_grid)\n",
"graph_components = wmg.split_graph_by_edge_attribute(graph=graph, attr=\"component\")\n",
"\n",
"fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 5))\n",
"\n",
"for (name, graph), ax in zip(graph_components.items(), axes.flatten()):\n",
"    pl_kwargs = {}\n",
"    if name == \"m2m\":\n",
"        pl_kwargs = dict(edge_color_attr=\"len\")\n",
"    elif name == \"g2m\" or name == \"m2g\":\n",
"        pl_kwargs = dict(edge_color_attr=\"len\", node_color_attr=\"type\")\n",
"\n",
"    wmg.visualise.nx_draw_with_pos_and_attr(graph, ax=ax, node_size=30, **pl_kwargs)\n",
"    ax.set_title(name)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
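
The notebook above describes grid nodes, mesh nodes, and the `g2m`, `m2m` and `m2g` edge components in prose. As a rough, dependency-free illustration of that structure, the sketch below builds the three components for a small regular grid and runs one encode-process-decode pass over them. It is not how `weather-model-graphs` builds its graphs: every function name here (`make_nodes`, `build_edges`, `split_by_component`, `propagate`) is hypothetical, the nearest-neighbour connectivity is an assumption, and plain averaging stands in for the learned message-passing layers of a real model.

```python
import math
from collections import defaultdict


def make_nodes(n, kind):
    """Regular n x n lattice on the unit square; node ids are (kind, i, j)."""
    step = 1.0 / (n - 1)
    return {(kind, i, j): (i * step, j * step) for i in range(n) for j in range(n)}


def nearest(pos, candidates):
    """Id of the candidate node closest to position `pos`."""
    return min(candidates, key=lambda k: math.dist(pos, candidates[k]))


def build_edges(grid, mesh, n_mesh):
    """Directed edges as (source, target, component) tuples."""
    edges = []
    # Assumption: each grid node talks only to its nearest mesh node.
    for g, pos in grid.items():
        m = nearest(pos, mesh)
        edges.append((g, m, "g2m"))  # encoding edge
        edges.append((m, g, "m2g"))  # decoding edge
    # Assumption: mesh nodes connect to lattice neighbours in both directions.
    for _, i, j in mesh:
        for di, dj in ((1, 0), (0, 1)):
            if i + di < n_mesh and j + dj < n_mesh:
                a, b = ("mesh", i, j), ("mesh", i + di, j + dj)
                edges.extend([(a, b, "m2m"), (b, a, "m2m")])
    return edges


def split_by_component(edges):
    """Group (source, target) pairs by their component label."""
    out = defaultdict(list)
    for source, target, component in edges:
        out[component].append((source, target))
    return dict(out)


def propagate(values, edge_list):
    """One message-passing step: each target takes the mean of its sources."""
    incoming = defaultdict(list)
    for source, target in edge_list:
        incoming[target].append(values[source])
    return {t: sum(v) / len(v) for t, v in incoming.items()}


N_GRID, N_MESH = 10, 4
grid = make_nodes(N_GRID, "grid")
mesh = make_nodes(N_MESH, "mesh")
components = split_by_component(build_edges(grid, mesh, N_MESH))

# Toy "atmospheric state": one scalar per grid node.
grid_state = {g: x + y for g, (x, y) in grid.items()}

latent = propagate(grid_state, components["g2m"])  # encode: grid -> mesh
latent = propagate(latent, components["m2m"])      # process: mesh -> mesh
prediction = propagate(latent, components["m2g"])  # decode: mesh -> grid
```

Keeping the component label on each edge is what makes the split trivial, which mirrors the role of the `component` edge attribute passed to `wmg.split_graph_by_edge_attribute` in the notebook.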