Skip to content

Commit

Permalink
Add flexible options to archetype graphs (#19)
Browse files Browse the repository at this point in the history
* Handle non-square areas for single-level flat graphs

* Fix tests

* Introduce radius graph connection with distance relative to longest edge

* Handle relative connection radius accurately in multiscale graphs

* Fix relative refinement factor for g2m also for hierarchical graphs

* Change multiscale archetype to enforce refinement factor 3

* Start work on creating separate refinement factors between grid-mesh and mesh levels

* Clean up separate refinement factors to also work with multiscale graph

* Only connect grid to bottom level of hierarchical mesh

* Fix m2g specification for archetypes

* Correct doscstrings to match updated archetypes

* Update docs to match archetype changes

* Fix tests

* Add test for new rel_max_dist within_radius parameter

* Fix some comments and checks

* Run pre-commit on docs

* Fix typos

Co-authored-by: sadamov <45732287+sadamov@users.noreply.github.com>

* Add explanation about why default distance is 0.51d

* Clarify docstring for split_on_edge_attribute_existance

* Clarify dimension names in mesh creation

* Clarify grid coordinates also in docstring for flat mesh graph

* Update changelog

* Clear outputs in documentation notebooks

---------

Co-authored-by: sadamov <45732287+sadamov@users.noreply.github.com>
  • Loading branch information
joeloskarsson and sadamov authored Sep 12, 2024
1 parent 92ae18b commit b468ed5
Show file tree
Hide file tree
Showing 14 changed files with 352 additions and 139 deletions.
25 changes: 25 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,28 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Vim ###
# Swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~

# Auto-generated tag files
tags

# Persistent undo
[._]*.un~

# Coc configuration directory
.vim
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [unreleased](https://github.com/mllam/weather-model-graphs/compare/v0.1.0...HEAD)

### Added

- added github pull-request template to ease contribution and review process
[\#18](https://github.com/mllam/weather-model-graphs/pull/18), @joeloskarsson

- Allow for specifying relative distance as `rel_max_dist` when connecting nodes using `within_radius` method.
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

### Changed

- Changed the `refinement_factor` argument into two: a `grid_refinement_factor` and a `level_refinement_factor`.
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

- Connect grid nodes only to the bottom level of hierarchical mesh graphs.
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

- Change default archetypes to match the graph creation from neural-lam.
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

## [v0.1.0](https://github.com/mllam/weather-model-graphs/releases/tag/v0.1.0)

First tagged release of `weather-model-graphs` which includes functionality to
Expand Down
10 changes: 6 additions & 4 deletions docs/background.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"graph = wmg.create.archetype.create_keisler_graph(xy_grid=xy_grid)\n",
"graph = wmg.create.archetype.create_keisler_graph(\n",
" xy_grid=xy_grid, grid_refinement_factor=2\n",
")\n",
"graph_components = wmg.split_graph_by_edge_attribute(graph=graph, attr=\"component\")\n",
"\n",
"fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 5))\n",
Expand All @@ -104,7 +106,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -118,9 +120,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
25 changes: 10 additions & 15 deletions docs/creating_the_graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Oscarsson et al 2023 hierarchical graph"
"## Oskarsson et al 2023 hierarchical graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hierarchical graph from [Oscarsson et al 2023](https://arxiv.org/abs/2202.07575) builds on the GraphCast graph by adding a hierarchical structure to the mesh component of the graph. This allows the model to capture both short and long-range spatial interactions and to learn the spatial hierarchy of the data. The message-passing on different levels of interaction length-scales are learnt separately (rather than in a single pass) which allows the model to learn the spatial hierarchy of the data."
"The hierarchical graph from [Oskarsson et al 2023](https://arxiv.org/abs/2202.07575) builds on the GraphCast graph by adding a hierarchical structure to the mesh component of the graph. This allows the model to capture both short and long-range spatial interactions and to learn the spatial hierarchy of the data. The message-passing on different levels of interaction length-scales are learnt separately (rather than in a single pass) which allows the model to learn the spatial hierarchy of the data."
]
},
{
Expand All @@ -266,7 +266,7 @@
"metadata": {},
"outputs": [],
"source": [
"?wmg.create.archetype.create_oscarsson_hierarchical_graph"
"?wmg.create.archetype.create_oskarsson_hierarchical_graph"
]
},
{
Expand All @@ -275,7 +275,7 @@
"metadata": {},
"outputs": [],
"source": [
"graph = wmg.create.archetype.create_oscarsson_hierarchical_graph(xy_grid=xy)\n",
"graph = wmg.create.archetype.create_oskarsson_hierarchical_graph(xy_grid=xy)\n",
"graph"
]
},
Expand Down Expand Up @@ -411,7 +411,9 @@
"graph = wmg.create.create_all_graph_components(\n",
" m2m_connectivity=\"flat_multiscale\",\n",
" xy=xy,\n",
" m2m_connectivity_kwargs=dict(refinement_factor=3, max_num_levels=None),\n",
" m2m_connectivity_kwargs=dict(\n",
" grid_refinement_factor=2, level_refinement_factor=3, max_num_levels=None\n",
" ),\n",
" g2m_connectivity=\"nearest_neighbour\",\n",
" m2g_connectivity=\"nearest_neighbour\",\n",
")\n",
Expand All @@ -438,18 +440,11 @@
" ax.set_title(name)\n",
" ax.set_aspect(1.0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -463,9 +458,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
106 changes: 69 additions & 37 deletions src/weather_model_graphs/create/archetype.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
from .base import create_all_graph_components


def create_keisler_graph(xy_grid):
def create_keisler_graph(xy_grid, grid_refinement_factor=3):
"""
Create a graph following Keisler (2022, https://arxiv.org/abs/2202.07575) architecture.
Create a flat LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
This graph setup is inspired by the global graph used by Keisler (2022, https://arxiv.org/abs/2202.07575).
This graph is a flat multiscale graph with nearest neighbour connectivity
(8 neighbours) within the mesh. The grid to mesh connectivity connects each mesh node to
the four nearest grid points. The mesh to grid connectivity connects each grid point to the
nearest mesh node.
This graph is a flat single scale graph with nearest neighbour connectivity
(8 neighbours) within the mesh.
TODO: Verify that Keisler does in fact use these g2m and m2g connectivities.
The grid to mesh connectivity connects each mesh node to grid nodes withing
distance 0.51d, where d is the length of diagonal edges between neighbouring
mesh nodes. The choice of 0.51 makes sure that all grid node positions will
be connected to at least one mesh node (see
https://www.desmos.com/calculator/sqqz0ka4ho for a visualization).
The mesh to grid connectivity connects each grid point to the 4 nearest mesh nodes.
Parameters
----------
xy_grid: np.ndarray
2D array of grid point positions.
merge_components: bool
Whether to merge the components of the graph.
grid_refinement_factor: float
Refinement factor between grid points and mesh
Returns
-------
Expand All @@ -27,34 +31,43 @@ def create_keisler_graph(xy_grid):
return create_all_graph_components(
xy=xy_grid,
m2m_connectivity="flat",
m2m_connectivity_kwargs={},
m2g_connectivity="nearest_neighbour",
g2m_connectivity="nearest_neighbours",
m2m_connectivity_kwargs=dict(grid_refinement_factor=grid_refinement_factor),
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
g2m_connectivity_kwargs=dict(
rel_max_dist=0.51,
),
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
)


def create_graphcast_graph(xy_grid, refinement_factor=3, max_num_levels=None):
def create_graphcast_graph(
xy_grid, grid_refinement_factor=3, level_refinement_factor=3, max_num_levels=None
):
"""
Create a graph following the Lam et al (2023, https://arxiv.org/abs/2212.12794) GraphCast architecture.
Create a multiscale LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
This graph setup is inspired by the global GraphCast graph used by Lam et al (2023, https://arxiv.org/abs/2212.12794)
This graph is a flat multiscale graph with nearest neighbour connectivity (4 neighbours) with both nearest
neighbour and longer range connections in the mesh, using the `refinement_factor` and `max_num_levels` parameters
to constrain the range-length of the connections. The grid to mesh connectivity connects each mesh node to
to its nearest 4 grid points. The mesh to grid connectivity connects each grid point to the nearest mesh node.
This graph is a flat multiscale graph with neighbour connectivity and longer multi-scale edges.
TODO: Verify that GraphCast does in fact use these g2m and m2g connectivities.
The grid to mesh connectivity connects each mesh node to grid nodes withing
distance 0.51d, where d is the length of diagonal edges between neighbouring
mesh nodes. The choice of 0.51 makes sure that all grid node positions will
be connected to at least one mesh node (see
https://www.desmos.com/calculator/sqqz0ka4ho for a visualization).
The mesh to grid connectivity connects each grid point to the 4 nearest mesh nodes.
Parameters
----------
xy_grid: np.ndarray
2D array of grid point positions.
refinement_factor: int
Refinement factor for longer-range connections in the mesh graph, the
reduction factor in the number of mesh points between levels (in both
x and y directions).
grid_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy
level_refinement_factor: int
Refinement factor between grid points and bottom level of mesh hierarchy
NOTE: Must be an odd integer >1 to create proper multiscale graph
max_num_levels: int
The number of levels of longer-range connections in the mesh graph.
Expand All @@ -67,39 +80,51 @@ def create_graphcast_graph(xy_grid, refinement_factor=3, max_num_levels=None):
xy=xy_grid,
m2m_connectivity="flat_multiscale",
m2m_connectivity_kwargs=dict(
refinement_factor=refinement_factor, max_num_levels=max_num_levels
grid_refinement_factor=grid_refinement_factor,
level_refinement_factor=level_refinement_factor,
max_num_levels=max_num_levels,
),
m2g_connectivity="nearest_neighbour",
g2m_connectivity="nearest_neighbours",
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
g2m_connectivity_kwargs=dict(
rel_max_dist=0.51,
),
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
)


def create_oscarsson_hierarchical_graph(xy_grid):
def create_oskarsson_hierarchical_graph(
xy_grid, grid_refinement_factor=3, level_refinement_factor=3, max_num_levels=None
):
"""
Create a graph following Oscarsson et al (2023, https://arxiv.org/abs/2309.17370)
Create a LAM graph following Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
hierarchical architecture.
The mesh graph in this architecture is hierarchical in that each refinement of
longer-range edges are split into different levels. In addition to these same-level
connections the mesh graph contains nearest neighbour connections between
levels (up and down). To distinguish between these these three types of
edge connections each edge has a `direction` attribute (with value "up",
"down", or "same"). In addition the `level` attribute indicates which two levels
"down", or "same"). In addition, the `levels` attribute indicates which two levels
are connected for cross-level edges (e.g. "1>2" for edges between level 1 and 2).
The grid to mesh connectivity connects each mesh node to the four nearest
grid points, and the mesh to grid connectivity connects each grid point to
the nearest mesh node.
TODO: Is this the right connectivity for the g2m and m2g components?
The grid to mesh connectivity connects each mesh node to grid nodes withing
distance 0.51d, where d is the length of diagonal edges between neighbouring
mesh nodes. The choice of 0.51 makes sure that all grid node positions will
be connected to at least one mesh node (see
https://www.desmos.com/calculator/sqqz0ka4ho for a visualization).
The mesh to grid connectivity connects each grid point to the 4 nearest mesh nodes.
Parameters
----------
xy_grid: np.ndarray
2D array of grid point positions.
grid_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy
level_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy
Returns
-------
Expand All @@ -109,10 +134,17 @@ def create_oscarsson_hierarchical_graph(xy_grid):
return create_all_graph_components(
xy=xy_grid,
m2m_connectivity="hierarchical",
m2m_connectivity_kwargs=dict(refinement_factor=2, max_num_levels=3),
m2g_connectivity="nearest_neighbour",
g2m_connectivity="nearest_neighbours",
m2m_connectivity_kwargs=dict(
grid_refinement_factor=grid_refinement_factor,
level_refinement_factor=level_refinement_factor,
max_num_levels=max_num_levels,
),
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
g2m_connectivity_kwargs=dict(
rel_max_dist=0.51,
),
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
)
Loading

0 comments on commit b468ed5

Please sign in to comment.