diff --git a/CHANGELOG.md b/CHANGELOG.md index a2fef38..2ce1af6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [2, Ny, Nx]), to allow for non-regularly gridded coordinates [\#32](https://github.com/mllam/weather-model-graphs/pull/32), @joeloskarsson +### Fixed + +- Fix crash when trying to create flat multiscale graphs with >= 3 levels + [\#41](https://github.com/mllam/weather-model-graphs/pull/41), @joeloskarsson + ## [v0.2.0](https://github.com/mllam/weather-model-graphs/releases/tag/v0.2.0) ### Added diff --git a/src/weather_model_graphs/create/mesh/kinds/flat.py b/src/weather_model_graphs/create/mesh/kinds/flat.py index ab27a08..cf14fc5 100644 --- a/src/weather_model_graphs/create/mesh/kinds/flat.py +++ b/src/weather_model_graphs/create/mesh/kinds/flat.py @@ -52,11 +52,14 @@ def create_flat_multiscale_mesh_graph( G_tot = G_all_levels[0] # First node at level l+1 share position with node (offset, offset) at level l level_offset = level_refinement_factor // 2 + + first_level_nodes = list(G_all_levels[0].nodes) + # Last nodes in first layer has pos (nx-1, ny-1) + num_nodes_x = first_level_nodes[-1][0] + 1 + num_nodes_y = first_level_nodes[-1][1] + 1 + for lev in range(1, len(G_all_levels)): nodes = list(G_all_levels[lev - 1].nodes) - # Last nodes always has pos (nx-1, ny-1) - num_nodes_x = nodes[-1][0] + 1 - num_nodes_y = nodes[-1][1] + 1 ij = ( np.array(nodes) .reshape((num_nodes_x, num_nodes_y, 2))[ @@ -72,6 +75,10 @@ def create_flat_multiscale_mesh_graph( ) G_tot = networkx.compose(G_tot, G_all_levels[lev]) + # Update number of nodes in x- and y-direction for next iteraion + num_nodes_x //= level_refinement_factor + num_nodes_y //= level_refinement_factor + # Relabel mesh nodes to start with 0 G_tot = prepend_node_index(G_tot, 0) diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index ea63d1c..8d2b057 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -192,3 +192,24 @@ def test_create_lat_lon(kind): coords_crs=coords_crs, graph_crs=graph_crs, ) + + +@pytest.mark.parametrize("kind", ["graphcast", "oskarsson_hierarchical"]) +def test_create_many_levels(kind): + """Test that mesh graph creation methods that work with many levels + can handle more than 2 levels + """ + # Test 4 levels at lrf=3 + grid_coord_range = 10 + level_refinement_factor = 3 + mesh_node_distance = grid_coord_range / 3**4 + + xy = test_utils.create_fake_xy(N=grid_coord_range) + fn_name = f"create_{kind}_graph" + fn = getattr(wmg.create.archetype, fn_name) + + fn( + coords=xy, + mesh_node_distance=mesh_node_distance, + level_refinement_factor=level_refinement_factor, + )