Commit 8c37547

Merge remote-tracking branch 'joel/research' into feature_verification

sadamov committed Feb 5, 2025
2 parents 6b7bcbd + 2e46ecc, commit 8c37547
Showing 5 changed files with 62 additions and 14 deletions.
7 changes: 5 additions & 2 deletions neural_lam/models/ar_model.py
@@ -65,11 +65,14 @@ def __init__(
             "state_std": torch.tensor(
                 da_state_stats.state_std.values, dtype=torch.float32
             ),
+            # Change stats below to be for diff of standardized variables
             "diff_mean": torch.tensor(
-                da_state_stats.state_diff_mean.values, dtype=torch.float32
+                da_state_stats.state_diff_mean.values / da_state_stats.state_std.values,
+                dtype=torch.float32
             ),
             "diff_std": torch.tensor(
-                da_state_stats.state_diff_std.values, dtype=torch.float32
+                da_state_stats.state_diff_std.values / da_state_stats.state_std.values,
+                dtype=torch.float32
             ),
         }
 
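The rescaling above follows directly from standardization: if z_t = (x_t - state_mean) / state_std, then z_{t+1} - z_t = (x_{t+1} - x_t) / state_std, so the mean and std of one-step differences both divide by state_std. A minimal numeric check of that identity, using made-up values rather than real dataset statistics:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(10_000, 2))  # toy (time, num_vars) series

state_mean, state_std = x.mean(axis=0), x.std(axis=0)
z = (x - state_mean) / state_std  # standardized variables

diff_raw = np.diff(x, axis=0)  # x_{t+1} - x_t, raw units
diff_standardized = np.diff(z, axis=0)  # z_{t+1} - z_t

# Diff stats of standardized variables = raw diff stats / state_std
assert np.allclose(diff_standardized.mean(axis=0), diff_raw.mean(axis=0) / state_std)
assert np.allclose(diff_standardized.std(axis=0), diff_raw.std(axis=0) / state_std)
```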
47 changes: 39 additions & 8 deletions neural_lam/models/base_graph_model.py
@@ -3,6 +3,7 @@
 
 # Third-party
 import torch
+from torch import nn
 
 # Local
 from .. import utils
@@ -59,15 +60,26 @@ def __init__(
             f"{self.num_mesh_nodes} mesh)"
         )
 
+        # Determine grid hidden dim
+        if args.hidden_dim_grid is None:
+            # Same as hidden_dim
+            hidden_dim_grid = args.hidden_dim
+        else:
+            hidden_dim_grid = args.hidden_dim_grid
+
         # interior_dim from data + static
         self.g2m_edges, g2m_dim = self.g2m_features.shape
         self.m2g_edges, m2g_dim = self.m2g_features.shape
 
         # Define sub-models
         # Feature embedders for interior
         self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1)
+        # For grid hidden dim
+        self.grid_mlp_blueprint_end = [hidden_dim_grid] * (
+            args.hidden_layers + 1
+        )
         self.interior_embedder = utils.make_mlp(
-            [self.interior_dim] + self.mlp_blueprint_end
+            [self.interior_dim] + self.grid_mlp_blueprint_end
         )
 
         if self.boundary_forced:
@@ -83,37 +95,50 @@ def __init__(
                 self.boundary_embedder = self.interior_embedder
         else:
             self.boundary_embedder = utils.make_mlp(
-                [self.boundary_dim] + self.mlp_blueprint_end
+                [self.boundary_dim] + self.grid_mlp_blueprint_end
             )
 
-        self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end)
-        self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end)
+        # Projections between grid dim and hidden dim before and after processor
+        self.pre_mesh_proj = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_dim_grid, args.hidden_dim)
+        )
+        self.post_mesh_proj = nn.Sequential(
+            nn.SiLU(), nn.Linear(args.hidden_dim, hidden_dim_grid)
+        )
+
+        self.g2m_embedder = utils.make_mlp(
+            [g2m_dim] + self.grid_mlp_blueprint_end
+        )
+        self.m2g_embedder = utils.make_mlp(
+            [m2g_dim] + self.grid_mlp_blueprint_end
+        )
 
         # GNNs
         # encoder
         self.g2m_gnn = InteractionNet(
             self.g2m_edge_index,
-            args.hidden_dim,
+            hidden_dim_grid,
             hidden_layers=args.hidden_layers,
             update_edges=False,
             num_rec=self.num_grid_connected_mesh_nodes,
         )
         self.encoding_grid_mlp = utils.make_mlp(
-            [args.hidden_dim] + self.mlp_blueprint_end
+            [hidden_dim_grid] + self.grid_mlp_blueprint_end
         )
 
         # decoder
         self.m2g_gnn = InteractionNet(
             self.m2g_edge_index,
-            args.hidden_dim,
+            hidden_dim_grid,
             hidden_layers=args.hidden_layers,
             update_edges=False,
             num_rec=self.num_interior_nodes,
         )
 
         # Output mapping (hidden_dim -> output_dim)
         self.output_map = utils.make_mlp(
-            [args.hidden_dim] * (args.hidden_layers + 1)
+            [hidden_dim_grid]
+            + [hidden_dim_grid] * args.hidden_layers
             + [self.grid_output_dim],
             layer_norm=False,
         )  # No layer norm on this one
@@ -432,9 +457,15 @@ def predict_step(
                 interior_emb
             )  # (B, num_interior_nodes, d_h)
 
+        # Project up mesh rep to hidden dim of graph
+        mesh_rep = self.pre_mesh_proj(mesh_rep)
+
         # Run processor step
         mesh_rep = self.process_step(mesh_rep)
 
+        # Project down mesh rep to hidden dim of grid
+        mesh_rep = self.post_mesh_proj(mesh_rep)
+
         # Map back from mesh to grid
         m2g_emb_expanded = self.expand_to_batch(m2g_emb, batch_size)
         grid_rep = self.m2g_gnn(
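Taken together, these changes let the grid side (the embedders, g2m/m2g GNNs, encoding_grid_mlp, and output_map) operate at hidden_dim_grid while the processor stays at hidden_dim, with pre_mesh_proj and post_mesh_proj bridging the two widths around process_step. A standalone sketch of that shape flow, with illustrative dimensions and an identity placeholder standing in for the processor:

```python
import torch
from torch import nn

batch_size, num_mesh_nodes = 4, 100
hidden_dim, hidden_dim_grid = 64, 32  # illustrative values

# Same construction as pre_mesh_proj / post_mesh_proj in the diff
pre_mesh_proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim_grid, hidden_dim))
post_mesh_proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim_grid))
process_step = nn.Identity()  # stand-in for the real processor GNN

mesh_rep = torch.randn(batch_size, num_mesh_nodes, hidden_dim_grid)  # from g2m_gnn
mesh_rep = pre_mesh_proj(mesh_rep)   # (B, num_mesh_nodes, hidden_dim)
mesh_rep = process_step(mesh_rep)    # processor runs entirely at hidden_dim
mesh_rep = post_mesh_proj(mesh_rep)  # (B, num_mesh_nodes, hidden_dim_grid)
print(mesh_rep.shape)  # torch.Size([4, 100, 32])
```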
6 changes: 4 additions & 2 deletions neural_lam/models/base_hi_graph_model.py
@@ -63,9 +63,11 @@ def __init__(
 
         # Separate mesh node embedders for each level
         self.mesh_embedders = nn.ModuleList(
-            [
+            # Bottom mesh level is first embedded to hidden dim of grid
+            [utils.make_mlp([mesh_dim] + self.grid_mlp_blueprint_end)]
+            + [
                 utils.make_mlp([mesh_dim] + self.mlp_blueprint_end)
-                for _ in range(self.num_levels)
+                for _ in range(self.num_levels - 1)
             ]
         )
         self.mesh_same_embedders = nn.ModuleList(
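Only the bottom mesh level exchanges messages with the grid, so only its embedder now targets hidden_dim_grid; the remaining num_levels - 1 embedders still map to hidden_dim. A self-contained sketch of that split, using a minimal stand-in for utils.make_mlp (assumed Linear + SiLU layers; the real helper also takes a layer_norm flag):

```python
from torch import nn

def make_mlp(dims):
    # Minimal stand-in for utils.make_mlp (assumption: Linear + SiLU per layer)
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), nn.SiLU()]
    return nn.Sequential(*layers)

mesh_dim, hidden_layers, num_levels = 3, 1, 3
hidden_dim, hidden_dim_grid = 64, 32  # illustrative values
mlp_blueprint_end = [hidden_dim] * (hidden_layers + 1)
grid_mlp_blueprint_end = [hidden_dim_grid] * (hidden_layers + 1)

mesh_embedders = nn.ModuleList(
    # Bottom level talks to the grid, so embed it to hidden_dim_grid
    [make_mlp([mesh_dim] + grid_mlp_blueprint_end)]
    # Higher levels only interact with other mesh levels, at hidden_dim
    + [make_mlp([mesh_dim] + mlp_blueprint_end) for _ in range(num_levels - 1)]
)
print([emb[0].out_features for emb in mesh_embedders])  # [32, 64, 64]
```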
5 changes: 4 additions & 1 deletion neural_lam/models/graph_lam.py
@@ -47,7 +47,10 @@ def __init__(
 
         # Define sub-models
         # Feature embedders for mesh
-        self.mesh_embedder = utils.make_mlp([mesh_dim] + self.mlp_blueprint_end)
+        # Bottom mesh level is first embedded to hidden dim of grid
+        self.mesh_embedder = utils.make_mlp(
+            [mesh_dim] + self.grid_mlp_blueprint_end
+        )
         self.m2m_embedder = utils.make_mlp([m2m_dim] + self.mlp_blueprint_end)
 
         # GNNs
11 changes: 10 additions & 1 deletion neural_lam/train_model.py
@@ -100,7 +100,16 @@ def main(input_args=None):
         "--hidden_dim",
         type=int,
         default=64,
-        help="Dimensionality of all hidden representations (default: 64)",
+        help="Dimensionality of hidden representations (default: 64)",
+    )
+    parser.add_argument(
+        "--hidden_dim_grid",
+        type=int,
+        help=(
+            "Dimensionality of hidden representations related to grid nodes "
+            "(grid encodings and in grid-level MLPs) "
+            "(default: None, use same as hidden_dim)"
+        ),
     )
     parser.add_argument(
         "--hidden_layers",
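Since --hidden_dim_grid has no default, args.hidden_dim_grid is None unless passed, and base_graph_model.py then falls back to hidden_dim. A minimal sketch of that resolution logic in isolation:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--hidden_dim", type=int, default=64)
parser.add_argument("--hidden_dim_grid", type=int)  # default is None

def resolve(args):
    # Same fallback branch as in base_graph_model.py
    if args.hidden_dim_grid is None:
        return args.hidden_dim
    return args.hidden_dim_grid

print(resolve(parser.parse_args([])))  # 64: falls back to hidden_dim
print(resolve(parser.parse_args(["--hidden_dim_grid", "32"])))  # 32
```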
