Skip to content

Commit

Permalink
add test and example notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Dec 9, 2024
1 parent ee8601e commit e178a0b
Show file tree
Hide file tree
Showing 3 changed files with 488 additions and 3 deletions.
473 changes: 473 additions & 0 deletions docs/decoding_mask.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/weather_model_graphs/visualise/plot_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def nx_draw_with_pos_and_attr(
node_zorder_attr=None,
node_size=100,
connectionstyle="arc3, rad=0.1",
with_labels=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -171,7 +172,7 @@ def nx_draw_with_pos_and_attr(
graph,
ax=ax,
arrows=True,
with_labels=False,
with_labels=with_labels,
node_size=node_size,
connectionstyle=connectionstyle,
**kwargs,
Expand Down
15 changes: 13 additions & 2 deletions tests/test_graph_decode_gridpoints_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,26 @@ def test_graph_decode_gridpoints_mask():
name_unfiltered, output_directory=tmpdirname
)

# manually filter the edge connections from
# check that the edges in the graph objects match the adjency matrices that
# have been written
graph_edges_unfiltered = np.array(unfiltered_graph.edges)
assert set(zip(*graph_edges_unfiltered.T)) == set(zip(*adj_unfiltered))
graph_edges_filtered = np.array(filtered_graph.edges)
assert set(zip(*graph_edges_filtered.T)) == set(zip(*adj_filtered))

# manually filter the edge connections to the grid-nodes that are masked
# out and create an adjency matrix
grid_indexes_to_remove = np.arange(0, xy.shape[0])[decode_mask == 0]
adj_pairs = []
for i in range(adj_unfiltered.shape[1]):
m_idx, g_idx = adj_unfiltered[:, i]
if g_idx in grid_indexes_to_remove:
continue
adj_pairs.append((m_idx, g_idx))

adj_unfiltered_masked = np.array(adj_pairs).T

import ipdb

ipdb.set_trace()

np.testing.assert_equal(adj_filtered, adj_unfiltered_masked)

0 comments on commit e178a0b

Please sign in to comment.