Skip to content

Commit

Permalink
Create different number of mesh nodes in x- and y-direction (#21)
Browse files Browse the repository at this point in the history
* Implement separate refinement in x and y for multirange 2d mesh

* Change to compute based on coordinate array shape rather than coordinate values

* Change flattening of multiscale graphs to handle different x and y dims

* Add test for non-square graph craetion

* Add changelog entry
  • Loading branch information
joeloskarsson authored Sep 18, 2024
1 parent b468ed5 commit 6406833
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Create different number of mesh nodes in x- and y-direction.
[\#21](https://github.com/mllam/weather-model-graphs/pull/21)
@joeloskarsson

- Changed the `refinement_factor` argument into two: a `grid_refinement_factor` and a `level_refinement_factor`.
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson
Expand Down
8 changes: 5 additions & 3 deletions src/weather_model_graphs/create/mesh/kinds/flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,17 @@ def create_flat_multiscale_mesh_graph(
level_offset = level_refinement_factor // 2
for lev in range(1, len(G_all_levels)):
nodes = list(G_all_levels[lev - 1].nodes)
n = int(np.sqrt(len(nodes)))
# Last nodes always has pos (nx-1, ny-1)
num_nodes_x = nodes[-1][0] + 1
num_nodes_y = nodes[-1][1] + 1
ij = (
np.array(nodes)
.reshape((n, n, 2))[
.reshape((num_nodes_x, num_nodes_y, 2))[
level_offset::level_refinement_factor,
level_offset::level_refinement_factor,
:,
]
.reshape(int(n / level_refinement_factor) ** 2, 2)
.reshape(int(num_nodes_x * num_nodes_y / (level_refinement_factor**2)), 2)
)
ij = [tuple(x) for x in ij]
G_all_levels[lev] = networkx.relabel_nodes(
Expand Down
40 changes: 28 additions & 12 deletions src/weather_model_graphs/create/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,37 +94,53 @@ def create_multirange_2d_mesh_graphs(
max_num_levels : int
Number of edge-distance levels in mesh graph
xy : np.ndarray
Grid point coordinates
Grid point coordinates, shaped [2, M, N]
refinement_factor : int
Degree of refinement between successive mesh graphs, the number of nodes
grows by refinement_factor**2 between successive mesh graphs
grows by approximately refinement_factor**2 between successive
mesh graphs.
Returns
-------
G_all_levels : list of networkx.Graph
List of networkx graphs for each level representing the connectivity
of the mesh within each level
"""
# Compute the size (grid nodes) along x and y direction of area
# to cover with graph
coord_extent = np.array((xy.shape[2], xy.shape[1]))

# Find the number of mesh levels possible in x- and y-direction,
# and the number of leaf nodes that would correspond to
# max_coord/(grid_refinement_factor*level_refinement_factor^mesh_levels) = 1
mesh_levels = int(
(np.log(max(xy.shape)) - np.log(grid_refinement_factor))
max_mesh_levels = (
(np.log(coord_extent) - np.log(grid_refinement_factor))
/ np.log(level_refinement_factor)
)
).astype(
int
) # (2,)
nleaf = grid_refinement_factor * (
level_refinement_factor**mesh_levels
) # leaves at the bottom = nleaf**2
level_refinement_factor**max_mesh_levels
) # leaves at the bottom in each direction, if using max_mesh_levels

# As we can not instantiate different number of mesh levels in each
# direction, create mesh levels corresponding to the minimum of the two
mesh_levels_to_create = max_mesh_levels.min()

if max_num_levels:
# Limit the levels in mesh graph
mesh_levels = min(mesh_levels, max_num_levels)
mesh_levels_to_create = min(mesh_levels_to_create, max_num_levels)

logger.debug(f"mesh_levels: {mesh_levels}, nleaf: {nleaf}")
logger.debug(f"mesh_levels: {mesh_levels_to_create}, nleaf: {nleaf}")

# multi resolution tree levels
G_all_levels = []
for lev in range(mesh_levels): # 0-index mesh levels
n = int(nleaf / (grid_refinement_factor * (level_refinement_factor**lev)))
g = create_single_level_2d_mesh_graph(xy, n, n)
for lev in range(mesh_levels_to_create): # 0-index mesh levels
# Compute number of nodes on level separate for each direction
nodes_x, nodes_y = (
nleaf / (grid_refinement_factor * (level_refinement_factor**lev))
).astype(int)
g = create_single_level_2d_mesh_graph(xy, nodes_x, nodes_y)
# Add level information to nodes, edges and full graph
for node in g.nodes:
g.nodes[node]["level"] = lev
Expand Down
26 changes: 26 additions & 0 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def _create_fake_xy(N=10):
return xy


def _create_rectangular_fake_xy(Nx=10, Ny=20):
x = np.linspace(0, 1, Nx)
y = np.linspace(0, 1, Ny)
xy = np.meshgrid(x, y)
xy = np.stack(xy, axis=0)
return xy


def test_create_single_level_mesh_graph():
xy = _create_fake_xy(N=4)
mesh_graph = wmg.create.mesh.create_single_level_2d_mesh_graph(xy=xy, nx=5, ny=5)
Expand Down Expand Up @@ -94,3 +102,21 @@ def test_create_graph_generic(m2g_connectivity, g2m_connectivity, m2m_connectivi
isinstance(graph, nx.DiGraph) for graph in graph_components.values()
)
assert set(graph_components.keys()) == {"m2m", "m2g", "g2m"}


@pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"])
def test_create_rectangular_graph(kind):
"""
Tests that graphs can be created for non-square areas, both thin and wide
"""
# Test thin
xy = _create_rectangular_fake_xy(Nx=20, Ny=64)
fn_name = f"create_{kind}_graph"
fn = getattr(wmg.create.archetype, fn_name)
fn(xy_grid=xy, grid_refinement_factor=2)

# Test wide
xy = _create_rectangular_fake_xy(Nx=64, Ny=20)
fn_name = f"create_{kind}_graph"
fn = getattr(wmg.create.archetype, fn_name)
fn(xy_grid=xy, grid_refinement_factor=2)

0 comments on commit 6406833

Please sign in to comment.