Add multi-node-training #103

Merged
merged 8 commits on Jan 23, 2025
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased](https://github.com/mllam/neural-lam/compare/v0.3.0...HEAD)

### Added
- Add support for multi-node training.
[\#103](https://github.com/mllam/neural-lam/pull/103) @simonkamuk @sadamov

### Fixed
- Only print on rank 0 to avoid duplicated output from every process.
[\#103](https://github.com/mllam/neural-lam/pull/103) @simonkamuk @sadamov

## [v0.3.0](https://github.com/mllam/neural-lam/releases/tag/v0.3.0)

This release introduces Datastores to represent input data from different sources (including zarr and numpy) while keeping graph generation within neural-lam.
30 changes: 30 additions & 0 deletions README.md
@@ -448,6 +448,36 @@ python -m neural_lam.train_model --model hi_lam_parallel --graph hierarchical ..

Checkpoint files for our models trained on the MEPS data are available upon request.

### High Performance Computing

The training script can be run on a cluster with multiple GPU nodes. Neural LAM is set up to use PyTorch Lightning's `DDP` strategy for distributed training.
The code can be used on systems both with and without SLURM. If the cluster has multiple nodes, set the `--num_nodes` argument accordingly.

On a cluster with SLURM, the job can be submitted with `sbatch slurm_job.sh`, using a job script like the following.
```bash
#!/bin/bash -l
#SBATCH --job-name=Neural-LAM
#SBATCH --time=24:00:00
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --partition=normal
#SBATCH --mem=444G
#SBATCH --no-requeue
#SBATCH --exclusive
#SBATCH --output=lightning_logs/neurallam_out_%j.log
#SBATCH --error=lightning_logs/neurallam_err_%j.log

# Load necessary modules or activate environment, for example:
conda activate neural-lam

srun -ul python -m neural_lam.train_model \
--config_path /path/to/config.yaml \
--num_nodes $SLURM_JOB_NUM_NODES
```

When running on a system without SLURM, where all GPUs are visible, a subset of GPUs can be selected for training with the `--devices` CLI argument, e.g. `--devices 0 1` to use the first two GPUs.
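For example, a single-node run restricted to the first two GPUs could look like the following (a minimal sketch, assuming a datastore config at `/path/to/config.yaml`):
```bash
python -m neural_lam.train_model \
    --config_path /path/to/config.yaml \
    --num_nodes 1 \
    --devices 0 1
```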

## Evaluate Models
Evaluation is also run with `python -m neural_lam.train_model --config_path <config-path>`, but with the `--eval` option added.
Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data.
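For example, evaluating on the test set could look like this (a minimal sketch, assuming the same config path as in the training examples above):
```bash
python -m neural_lam.train_model --config_path /path/to/config.yaml --eval test
```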
9 changes: 5 additions & 4 deletions neural_lam/datastore/mdp.py
@@ -13,6 +13,7 @@
from numpy import ndarray

# Local
from ..utils import rank_zero_print
from .base import BaseRegularGridDatastore, CartesianGridShape


@@ -72,11 +73,11 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points

print("The loaded datastore contains the following features:")
rank_zero_print("The loaded datastore contains the following features:")
for category in ["state", "forcing", "static"]:
if len(self.get_vars_names(category)) > 0:
var_names = self.get_vars_names(category)
print(f" {category:<8s}: {' '.join(var_names)}")
rank_zero_print(f" {category:<8s}: {' '.join(var_names)}")

# check that all three train/val/test splits are available
required_splits = ["train", "val", "test"]
@@ -87,12 +88,12 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
f"splits: {available_splits}"
)

print("With the following splits (over time):")
rank_zero_print("With the following splits (over time):")
for split in required_splits:
da_split = self._ds.splits.sel(split_name=split)
da_split_start = da_split.sel(split_part="start").load().item()
da_split_end = da_split.sel(split_part="end").load().item()
print(f" {split:<8s}: {da_split_start} to {da_split_end}")
rank_zero_print(f" {split:<8s}: {da_split_start} to {da_split_end}")

# find out the dimension order for the stacking to grid-index
dim_order = None
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
@@ -34,7 +34,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):

# Specify dimensions of data
self.num_mesh_nodes, _ = self.get_num_mesh()
print(
utils.rank_zero_print(
f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} "
f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)"
)
10 changes: 6 additions & 4 deletions neural_lam/models/base_hi_graph_model.py
@@ -27,19 +27,21 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
] # Needs as python list for later

# Print some useful info
print("Loaded hierarchical graph with structure:")
utils.rank_zero_print("Loaded hierarchical graph with structure:")
for level_index, level_mesh_size in enumerate(self.level_mesh_sizes):
same_level_edges = self.m2m_features[level_index].shape[0]
print(
utils.rank_zero_print(
f"level {level_index} - {level_mesh_size} nodes, "
f"{same_level_edges} same-level edges"
)

if level_index < (self.num_levels - 1):
up_edges = self.mesh_up_features[level_index].shape[0]
down_edges = self.mesh_down_features[level_index].shape[0]
print(f" {level_index}<->{level_index + 1}")
print(f" - {up_edges} up edges, {down_edges} down edges")
utils.rank_zero_print(f" {level_index}<->{level_index + 1}")
utils.rank_zero_print(
f" - {up_edges} up edges, {down_edges} down edges"
)
# Embedders
# Assume all levels have same static feature dimensionality
mesh_dim = self.mesh_static_features[0].shape[1]
2 changes: 1 addition & 1 deletion neural_lam/models/graph_lam.py
@@ -27,7 +27,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
# grid_dim from data + static + batch_static
mesh_dim = self.mesh_static_features.shape[1]
m2m_edges, m2m_dim = self.m2m_features.shape
print(
utils.rank_zero_print(
f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, "
f"m2g={self.m2g_edges}"
)
27 changes: 27 additions & 0 deletions neural_lam/train_model.py
@@ -49,6 +49,22 @@ def main(input_args=None):
default=4,
help="Number of workers in data loader (default: 4)",
)
parser.add_argument(
"--num_nodes",
type=int,
default=1,
help="Number of nodes to use in DDP (default: 1)",
)
parser.add_argument(
"--devices",
nargs="+",
type=str,
default=["auto"],
help="Devices to use for training. Can be the string 'auto' or a list "
"of integer id's corresponding to the desired devices, e.g. "
"'--devices 0 1'. Note that this cannot be used with SLURM, instead "
"set 'ntasks-per-node' in the slurm setup (default: auto)",
)
parser.add_argument(
"--epochs",
type=int,
@@ -249,6 +265,15 @@ def main(input_args=None):
else:
device_name = "cpu"

# Set devices to use
if args.devices == ["auto"]:
devices = "auto"
else:
try:
devices = [int(i) for i in args.devices]
except ValueError:
raise ValueError("devices should be 'auto' or a list of integers")

# Load model parameters, use new args for model
ModelClass = MODELS[args.model]
model = ModelClass(args, config=config, datastore=datastore)
@@ -278,6 +303,8 @@
deterministic=True,
strategy="ddp",
accelerator=device_name,
num_nodes=args.num_nodes,
devices=devices,
logger=logger,
log_every_n_steps=1,
callbacks=[checkpoint_callback],
7 changes: 7 additions & 0 deletions neural_lam/utils.py
@@ -4,6 +4,7 @@

# Third-party
import torch
from pytorch_lightning.utilities import rank_zero_only
from torch import nn
from tueplots import bundles, figsizes

@@ -233,6 +234,12 @@ def fractional_plot_bundle(fraction):
return bundle


@rank_zero_only
def rank_zero_print(*args, **kwargs):
"""Print only from rank 0 process"""
print(*args, **kwargs)


def init_wandb_metrics(wandb_logger, val_steps):
"""
Set up wandb metrics to track