-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adapt pyg writing for neural-lam * fix missing and circular imports * make regular input grid assumption explicit * handle missing edge attributes in splitting * minor fix for neural-lam save * working book compilation * fully working first jupyterbook! * tweak intro * add docs build+deploy github action * cache pdm install * cleanup
- Loading branch information
Showing
21 changed files
with
1,858 additions
and
950 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
name: build-docs | ||
|
||
# Run this when the master or main branch changes | ||
on: | ||
push: | ||
branches: | ||
- master | ||
- main | ||
- jupyterbook-docs | ||
|
||
# This job installs dependencies, builds the book, and pushes it to `gh-pages` | ||
jobs: | ||
deploy-book: | ||
runs-on: ubuntu-latest | ||
permissions: | ||
pages: write | ||
id-token: write | ||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- name: Setup PDM | ||
uses: pdm-project/setup-pdm@v4 | ||
cache: true | ||
|
||
- name: Install dependencies | ||
run: pdm install --group docs | ||
|
||
# Build the book | ||
- name: Build the book | ||
run: | | ||
LOGURU_LEVEL=WARNING jupyter-book build docs/ | ||
# Upload the book's HTML as an artifact | ||
- name: Upload artifact | ||
uses: actions/upload-pages-artifact@v2 | ||
with: | ||
path: "_build/html" | ||
|
||
# Deploy the book's HTML to GitHub Pages | ||
- name: Deploy to GitHub Pages | ||
id: deployment | ||
uses: actions/deploy-pages@v2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Book settings | ||
# Learn more at https://jupyterbook.org/customize/config.html | ||
|
||
title: Graphs for data-driven weather models | ||
author: The ML LAM Community | ||
#logo: logo.png | ||
|
||
# Force re-execution of notebooks on each build. | ||
# See https://jupyterbook.org/content/execute.html | ||
execute: | ||
execute_notebooks: force | ||
|
||
# Define the name of the latex output file for PDF builds | ||
latex: | ||
latex_documents: | ||
targetname: book.tex | ||
|
||
# Add a bibtex file so that we can create citations | ||
bibtex_bibfiles: | ||
- references.bib | ||
|
||
# Information about where the book exists on the web | ||
repository: | ||
url: https://github.com/mllam/weather-model-graphs | ||
path_to_book: docs # Optional path to your book, relative to the repository root | ||
branch: master # Which branch of the repository should be used when creating links (optional) | ||
|
||
# Add GitHub buttons to your book | ||
# See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository | ||
html: | ||
use_issues_button: true | ||
use_repository_button: true | ||
|
||
sphinx: | ||
extra_extensions: | ||
- sphinxcontrib.mermaid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Table of contents | ||
# Learn more at https://jupyterbook.org/customize/toc.html | ||
|
||
format: jb-book | ||
root: intro | ||
chapters: | ||
- file: background | ||
- file: design | ||
- file: creating_the_graph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Background\n", | ||
"\n", | ||
"> *this section provides an introduction to what graph-based data-driven weather models are and how they work. Inline code using `weather-model-graphs` is used to demonstrate different parts of the graph*\n", | ||
"\n", | ||
"Current graph-based weather models use the [encode-process-decode paradigm](https://arxiv.org/abs/1806.01261) on [message-passing graphs](https://arxiv.org/abs/1704.01212) to do the auto-regressive temporal prediction of the atmospheric weather state to produce a weather forecast. \n", | ||
"The graphs are directed acyclic graphs (DAGs) with the nodes representing features (physical variables) at a given location in space and the edges representing flow of information.\n", | ||
"The encode-process-decode paradigm is a three-step process that involves encoding the input data into a latent space, processing the latent space to make predictions, and decoding the predictions to produce the output data. \n", | ||
"\n", | ||
"## The graph nodes\n", | ||
"\n", | ||
"Using the nomenclature of [Lam et al 2022](https://arxiv.org/abs/2212.12794) the nodes in `weather-model-graphs` are split into two types:\n", | ||
"\n", | ||
"- **grid nodes**: representing the physical variables of the atmospheric state at a specific `(x,y)` coordinate in the (input) initial state to the model and the (output) prediction of the model\n", | ||
"\n", | ||
"- **mesh nodes**: representing the latent space of the model at specific `(x,y)` coordinate in the intermediate (latent) representation of the model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import weather_model_graphs as wmg\n", | ||
"\n", | ||
"# create some fake cartesian coordinates\n", | ||
"def _create_fake_xy(N=10):\n", | ||
" x = np.linspace(0.0, 1.0, N)\n", | ||
" y = np.linspace(0.0, 1.0, N)\n", | ||
" xy = np.stack(np.meshgrid(x, y), axis=0)\n", | ||
" return xy\n", | ||
"\n", | ||
"\n", | ||
"xy_grid = _create_fake_xy(N=10)\n", | ||
"\n", | ||
"graph = wmg.create.archetype.create_keisler_graph(xy_grid=xy_grid)\n", | ||
"\n", | ||
"# remove all edges from the graph\n", | ||
"graph.remove_edges_from(list(graph.edges))\n", | ||
"\n", | ||
"ax = wmg.visualise.nx_draw_with_pos_and_attr(graph, node_color_attr=\"type\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"As you can see in the code snippet above, the only input that graph generation in `weather-model-graphs` requires are the static `(x,y)` *grid* coordinates of the atmospheric state as the state changes over time. These coordinates are used to create the **grid nodes** nodes of the graph, with a node for each `(x,y)` coordinate.\n", | ||
"\n", | ||
"In addition to grid nodes the graph also contains **mesh nodes** that represent the latent space of the model at a set of `(x,y)` coordinates (this is in general a different set of coordinates to the **grid nodes** coordinates)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## The graph edges\n", | ||
"\n", | ||
"With these two sets of nodes, the graph is constructed by connecting the **grid nodes** to the **mesh nodes** and the **mesh nodes** to each other.\n", | ||
"The edges between the **grid nodes** and the **mesh nodes** represent the encoding of the physical variables into the latent space of the model, while the edges between the **mesh nodes** represent the processing of the latent space through the time evolution of the atmospheric state.\n", | ||
"\n", | ||
"In summary, the complete message-passing graph consists of three components:\n", | ||
"\n", | ||
"- **grid-to-mesh** (`g2m`): the encoding compenent, where edges represent the encoding of physical variables into the latent space of the model\n", | ||
"\n", | ||
"- **mesh-to-mesh** (`m2m`): the processing component, where edges represent information flow between nodes updating the latent presentation at mesh nodes through the time evolution of the atmospheric state\n", | ||
"\n", | ||
"- **mesh-to-grid** (`m2g`): the decoding component, where edges represent the decoding of the latent space back into physical variables\n", | ||
"\n", | ||
"Practically, the **mesh-to-grid** and **grid-to-mesh** updates can probably also encode some of the time evolution processing, in addition to the latent space encoding/decoding, unless the GNN is trained specifically as an auto-encoder using the same graph as input and output." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"graph = wmg.create.archetype.create_keisler_graph(xy_grid=xy_grid)\n", | ||
"graph_components = wmg.split_graph_by_edge_attribute(graph=graph, attr=\"component\")\n", | ||
"\n", | ||
"fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 5))\n", | ||
"\n", | ||
"for (name, graph), ax in zip(graph_components.items(), axes.flatten()):\n", | ||
" pl_kwargs = {}\n", | ||
" if name == \"m2m\":\n", | ||
" pl_kwargs = dict(edge_color_attr=\"len\")\n", | ||
" elif name == \"g2m\" or name == \"m2g\":\n", | ||
" pl_kwargs = dict(edge_color_attr=\"len\", node_color_attr=\"type\")\n", | ||
"\n", | ||
" wmg.visualise.nx_draw_with_pos_and_attr(graph, ax=ax, node_size=30, **pl_kwargs)\n", | ||
" ax.set_title(name)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.