From 4a97a1209e8fceadef8162f13db9e4805681d22e Mon Sep 17 00:00:00 2001
From: sadamov <45732287+sadamov@users.noreply.github.com>
Date: Wed, 22 May 2024 10:22:16 +0200
Subject: [PATCH] Replace constants.py with data_config.yaml (#31)

**Summary**
This PR replaces the `constants.py` file with a `data_config.yaml` file.
Dataset-related settings can be defined by the user in the new yaml file.
Training-specific settings were added as additional flags to the
`train_model.py` routine. All respective calls to the old files were replaced.

**Rationale**
- Using a yaml file for the data config gives much more flexibility for the
  various datasets used in the community. It also facilitates the future use
  of forcing and boundary datasets. In a follow-up PR the dataset paths will
  be defined in the yaml file, removing the dependency on a pre-structured
  `/data` folder.
- It is best practice to define user input in a yaml file; using python
  scripts for that purpose is not common.
- The old `constants.py` actually combined both constants and variables; many
  "constants" should rather be flags to `train_model.py`.
- The introduction of a new `Config` class in `neural_lam/config.py` allows
  for very specific queries of the yaml and calculations based thereon. This
  branch shows future possibilities of such a class:
  https://github.com/joeloskarsson/neural-lam/tree/feature_dataset_yaml

**Testing**
Both training and evaluation of the model were successfully tested with the
`meps_example` dataset.

**Note**
@leifdenby Could you invite Thomas R. to this repo, in case he wants to give
his input on the yaml file? This PR should mostly serve as a basis for
discussion. Maybe we should add more information to the yaml file, as you
outline in https://github.com/mllam/mllam-data-prep. I think we should always
keep in mind what the repository will look like with realistic boundary
conditions and zarr archives as data input.

This PR solves parts of https://github.com/joeloskarsson/neural-lam/issues/23

---------

Co-authored-by: Simon Adamov
---
 .gitignore                    |   1 +
 CHANGELOG.md                  |  16 ++++-
 README.md                     |   5 +-
 create_grid_features.py       |  12 ++--
 create_mesh.py                |  13 ++--
 create_parameter_weights.py   |  20 +++---
 neural_lam/config.py          |  62 ++++++++++++++++++
 neural_lam/constants.py       | 120 ----------------------------------
 neural_lam/data_config.yaml   |  64 ++++++++++++++++++
 neural_lam/models/ar_model.py |  68 ++++++++++---------
 neural_lam/utils.py           |   7 +-
 neural_lam/vis.py             |  30 +++++----
 neural_lam/weather_dataset.py |   6 +-
 plot_graph.py                 |  11 ++--
 train_model.py                |  51 +++++++++++----
 15 files changed, 274 insertions(+), 212 deletions(-)
 create mode 100644 neural_lam/config.py
 delete mode 100644 neural_lam/constants.py
 create mode 100644 neural_lam/data_config.yaml

diff --git a/.gitignore b/.gitignore
index 7bb826a2..c9d914c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ graphs
 *.sif
 sweeps
 test_*.sh
+.vscode

 ### Python ###
 # Byte-compiled / optimized / DLL files
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 19ecdd41..823ac8b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,11 +5,14 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-
 ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)

 ### Added

+- Replaced `constants.py` with `data_config.yaml` for data configuration management
+  [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+  @sadamov
+
 - new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddiv output option
   [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a)
   @joeloskarsson
@@ -24,6 +27,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Changed

+- Updated scripts and modules to use `data_config.yaml` instead of `constants.py`
+  [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+  @sadamov
+
+- Added new flags in `train_model.py` for configuration previously in `constants.py`
+  [\#31](https://github.com/joeloskarsson/neural-lam/pull/31)
+  @sadamov
+
 - moved batch-static features ("water cover") into forcing component returned by `WeatherDataset`
   [\#13](https://github.com/joeloskarsson/neural-lam/pull/13)
   @joeloskarsson
@@ -44,8 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   [\#13](https://github.com/joeloskarsson/neural-lam/pull/13)
   @joeloskarsson

-
 ## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0)

 First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication
-(https://arxiv.org/abs/2309.17370)
+(<https://arxiv.org/abs/2309.17370>)
diff --git a/README.md b/README.md
index 67d9d9b1..ba0bb3fe 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ Still, some restrictions are inevitable:

 ## A note on the limited area setting
 Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
 There are still some parts of the code that are quite specific for the MEPS area use case.
-This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`).
+This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config`).
 If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic.
 We would be happy to support such enhancements.
 See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
@@ -104,13 +104,12 @@ The graph-related files are stored in a directory called `graphs`.

 ### Create remaining static features
 To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
-The main option to set for these is just which dataset to use.

 ## Weights & Biases Integration
 The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
 When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
 If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
-The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`.
+The W&B project name is set to `neural-lam`, but this can be changed in the flags of `train_model.py` (using argparse).
 See the [W&B documentation](https://docs.wandb.ai/) for details.

 If you would like to login and use W&B, run:
diff --git a/create_grid_features.py b/create_grid_features.py
index c9038103..c3714368 100644
--- a/create_grid_features.py
+++ b/create_grid_features.py
@@ -6,6 +6,9 @@
 import numpy as np
 import torch

+# First-party
+from neural_lam import config
+

 def main():
     """
@@ -13,14 +16,15 @@ def main():
     """
     parser = ArgumentParser(description="Training arguments")
     parser.add_argument(
-        "--dataset",
+        "--data_config",
         type=str,
-        default="meps_example",
-        help="Dataset to compute weights for (default: meps_example)",
+        default="neural_lam/data_config.yaml",
+        help="Path to data config file (default: neural_lam/data_config.yaml)",
     )
     args = parser.parse_args()
+    config_loader = config.Config.from_file(args.data_config)

-    static_dir_path = os.path.join("data", args.dataset, "static")
+    static_dir_path = os.path.join("data", config_loader.dataset.name, "static")

     # -- Static grid node features --
     grid_xy = torch.tensor(
diff --git a/create_mesh.py b/create_mesh.py
index cb524cd6..f04b4d4b 100644
--- a/create_mesh.py
+++ b/create_mesh.py
@@ -12,6 +12,9 @@
 import torch_geometric as pyg
 from torch_geometric.utils.convert import from_networkx

+# First-party
+from neural_lam import config
+

 def plot_graph(graph, title=None):
     fig, axis = plt.subplots(figsize=(8, 8), dpi=200)  # W,H
@@ -153,11 +156,10 @@ def prepend_node_index(graph, new_index):
 def main():
     parser = ArgumentParser(description="Graph generation arguments")
     parser.add_argument(
-        "--dataset",
+        "--data_config",
         type=str,
-        default="meps_example",
-        help="Dataset to load grid point coordinates from "
-        "(default: meps_example)",
+        default="neural_lam/data_config.yaml",
+        help="Path to data config file (default: neural_lam/data_config.yaml)",
     )
     parser.add_argument(
         "--graph",
@@ -187,7 +189,8 @@ def main():
     args = parser.parse_args()

     # Load grid positions
-    static_dir_path = os.path.join("data", args.dataset, "static")
+    config_loader = config.Config.from_file(args.data_config)
+    static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
     graph_dir_path = os.path.join("graphs", args.graph)
     os.makedirs(graph_dir_path, exist_ok=True)
diff --git a/create_parameter_weights.py b/create_parameter_weights.py
index 494a5e81..cae1ae3e 100644
--- a/create_parameter_weights.py
+++ b/create_parameter_weights.py
@@ -8,7 +8,7 @@
 from tqdm import tqdm

 # First-party
-from neural_lam import constants
+from neural_lam import config
 from neural_lam.weather_dataset import WeatherDataset


@@ -18,10 +18,10 @@ def main():
     """
     parser = ArgumentParser(description="Training arguments")
     parser.add_argument(
-        "--dataset",
+        "--data_config",
         type=str,
-        default="meps_example",
-        help="Dataset to compute weights for (default: meps_example)",
+        default="neural_lam/data_config.yaml",
+        help="Path to data config file (default: neural_lam/data_config.yaml)",
     )
     parser.add_argument(
         "--batch_size",
@@ -43,7 +43,8 @@ def main():
     )
     args = parser.parse_args()

-    static_dir_path = os.path.join("data", args.dataset, "static")
+    config_loader = config.Config.from_file(args.data_config)
+    static_dir_path = os.path.join("data", config_loader.dataset.name, "static")

     # Create parameter weights based on height
     # based on fig A.1 in graph cast paper
@@ -56,7 +57,10 @@ def main():
         "500": 0.03,
     }
     w_list = np.array(
-        [w_dict[par.split("_")[-2]] for par in constants.PARAM_NAMES]
+        [
+            w_dict[par.split("_")[-2]]
+            for par in config_loader.dataset.var_longnames
+        ]
     )
     print("Saving parameter weights...")
     np.save(
@@ -66,7 +70,7 @@ def main():

     # Load dataset without any subsampling
     ds = WeatherDataset(
-        args.dataset,
+        config_loader.dataset.name,
         split="train",
         subsample_step=1,
         pred_length=63,
@@ -113,7 +117,7 @@ def main():
     # Compute mean and std.-dev. of one-step differences across the dataset
     print("Computing mean and std.-dev. for one-step differences...")
     ds_standard = WeatherDataset(
-        args.dataset,
+        config_loader.dataset.name,
         split="train",
         subsample_step=1,
         pred_length=63,
diff --git a/neural_lam/config.py b/neural_lam/config.py
new file mode 100644
index 00000000..5891ea74
--- /dev/null
+++ b/neural_lam/config.py
@@ -0,0 +1,62 @@
+# Standard library
+import functools
+from pathlib import Path
+
+# Third-party
+import cartopy.crs as ccrs
+import yaml
+
+
+class Config:
+    """
+    Class for loading configuration files.
+
+    This class loads a configuration file and provides a way to access its
+    values as attributes.
+    """
+
+    def __init__(self, values):
+        self.values = values
+
+    @classmethod
+    def from_file(cls, filepath):
+        """Load a configuration file."""
+        if filepath.endswith(".yaml"):
+            with open(filepath, encoding="utf-8", mode="r") as file:
+                return cls(values=yaml.safe_load(file))
+        else:
+            raise NotImplementedError(Path(filepath).suffix)
+
+    def __getattr__(self, name):
+        keys = name.split(".")
+        value = self.values
+        for key in keys:
+            if key in value:
+                value = value[key]
+            else:
+                return None
+        if isinstance(value, dict):
+            return Config(values=value)
+        return value
+
+    def __getitem__(self, key):
+        value = self.values[key]
+        if isinstance(value, dict):
+            return Config(values=value)
+        return value
+
+    def __contains__(self, key):
+        return key in self.values
+
+    def num_data_vars(self):
+        """Return the number of data variables for a given key."""
+        return len(self.dataset.var_names)
+
+    @functools.cached_property
+    def coords_projection(self):
+        """Return the projection."""
+        proj_config = self.values["projection"]
+        proj_class_name = proj_config["class"]
+        proj_class = getattr(ccrs, proj_class_name)
+        proj_params = proj_config.get("kwargs", {})
+        return proj_class(**proj_params)
diff --git a/neural_lam/constants.py b/neural_lam/constants.py
deleted file mode 100644
index 527c31d8..00000000
--- a/neural_lam/constants.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Third-party
-import cartopy
-import numpy as np
-
-WANDB_PROJECT = "neural-lam"
-
-SECONDS_IN_YEAR = (
-    365 * 24 * 60 * 60
-)  # Assuming no leap years in dataset (2024 is next)
-
-# Log prediction error for these lead times
-VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19])
-
-# Log these metrics to wandb as scalar values for
-# specific variables and lead times
-# List of metrics to watch, including any prefix (e.g. val_rmse)
-METRICS_WATCH = []
-# Dict with variables and lead times to log watched metrics for
-# Format is a dictionary that maps from a variable index to
-# a list of lead time steps
-VAR_LEADS_METRICS_WATCH = {
-    6: [2, 19],  # t_2
-    14: [2, 19],  # wvint_0
-    15: [2, 19],  # z_1000
-}
-
-# Variable names
-PARAM_NAMES = [
-    "pres_heightAboveGround_0_instant",
-    "pres_heightAboveSea_0_instant",
-    "nlwrs_heightAboveGround_0_accum",
-    "nswrs_heightAboveGround_0_accum",
-    "r_heightAboveGround_2_instant",
-    "r_hybrid_65_instant",
-    "t_heightAboveGround_2_instant",
-    "t_hybrid_65_instant",
-    "t_isobaricInhPa_500_instant",
-    "t_isobaricInhPa_850_instant",
-    "u_hybrid_65_instant",
-    "u_isobaricInhPa_850_instant",
-    "v_hybrid_65_instant",
-    "v_isobaricInhPa_850_instant",
-    "wvint_entireAtmosphere_0_instant",
-    "z_isobaricInhPa_1000_instant",
-    "z_isobaricInhPa_500_instant",
-]
-
-PARAM_NAMES_SHORT = [
-    "pres_0g",
-    "pres_0s",
-    "nlwrs_0",
-    "nswrs_0",
-    "r_2",
-    "r_65",
-    "t_2",
-    "t_65",
-    "t_500",
-    "t_850",
-    "u_65",
-    "u_850",
-    "v_65",
-    "v_850",
-    "wvint_0",
-    "z_1000",
-    "z_500",
-]
-PARAM_UNITS = [
-    "Pa",
-    "Pa",
-    "W/m\\textsuperscript{2}",
-    "W/m\\textsuperscript{2}",
-    "-",  # unitless
-    "-",
-    "K",
-    "K",
-    "K",
-    "K",
-    "m/s",
-    "m/s",
-    "m/s",
-    "m/s",
-    "kg/m\\textsuperscript{2}",
-    "m\\textsuperscript{2}/s\\textsuperscript{2}",
-    "m\\textsuperscript{2}/s\\textsuperscript{2}",
-]
-
-# Projection and grid
-# Hard coded for now, but should eventually be part of dataset desc. files
-GRID_SHAPE = (268, 238)  # (y, x)
-
-LAMBERT_PROJ_PARAMS = {
-    "a": 6367470,
-    "b": 6367470,
-    "lat_0": 63.3,
-    "lat_1": 63.3,
-    "lat_2": 63.3,
-    "lon_0": 15.0,
-    "proj": "lcc",
-}
-
-GRID_LIMITS = [  # In projection
-    -1059506.5523409774,  # min x
-    1310493.4476590226,  # max x
-    -1331732.4471934352,  # min y
-    1338267.5528065648,  # max y
-]
-
-# Create projection
-LAMBERT_PROJ = cartopy.crs.LambertConformal(
-    central_longitude=LAMBERT_PROJ_PARAMS["lon_0"],
-    central_latitude=LAMBERT_PROJ_PARAMS["lat_0"],
-    standard_parallels=(
-        LAMBERT_PROJ_PARAMS["lat_1"],
-        LAMBERT_PROJ_PARAMS["lat_2"],
-    ),
-)
-
-# Data dimensions
-GRID_FORCING_DIM = 5 * 3 + 1  # 5 feat. for 3 time-step window + 1 batch-static
-GRID_STATE_DIM = 17
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
new file mode 100644
index 00000000..f16a4a30
--- /dev/null
+++ b/neural_lam/data_config.yaml
@@ -0,0 +1,64 @@
+dataset:
+  name: meps_example
+  var_names:
+    - pres_0g
+    - pres_0s
+    - nlwrs_0
+    - nswrs_0
+    - r_2
+    - r_65
+    - t_2
+    - t_65
+    - t_500
+    - t_850
+    - u_65
+    - u_850
+    - v_65
+    - v_850
+    - wvint_0
+    - z_1000
+    - z_500
+  var_units:
+    - Pa
+    - Pa
+    - r"$\mathrm{W}/\mathrm{m}^2$"
+    - r"$\mathrm{W}/\mathrm{m}^2$"
+    - ""
+    - ""
+    - K
+    - K
+    - K
+    - K
+    - m/s
+    - m/s
+    - m/s
+    - m/s
+    - r"$\mathrm{kg}/\mathrm{m}^2$"
+    - r"$\mathrm{m}^2/\mathrm{s}^2$"
+    - r"$\mathrm{m}^2/\mathrm{s}^2$"
+  var_longnames:
+    - pres_heightAboveGround_0_instant
+    - pres_heightAboveSea_0_instant
+    - nlwrs_heightAboveGround_0_accum
+    - nswrs_heightAboveGround_0_accum
+    - r_heightAboveGround_2_instant
+    - r_hybrid_65_instant
+    - t_heightAboveGround_2_instant
+    - t_hybrid_65_instant
+    - t_isobaricInhPa_500_instant
+    - t_isobaricInhPa_850_instant
+    - u_hybrid_65_instant
+    - u_isobaricInhPa_850_instant
+    - v_hybrid_65_instant
+    - v_isobaricInhPa_850_instant
+    - wvint_entireAtmosphere_0_instant
+    - z_isobaricInhPa_1000_instant
+    - z_isobaricInhPa_500_instant
+  num_forcing_features: 16
+grid_shape_state: [268, 238]
+projection:
+  class: LambertConformal
+  kwargs:
+    central_longitude: 15.0
+    central_latitude: 63.3
+    standard_parallels: [63.3, 63.3]
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 7d0a8320..9cda9fc2 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,7 +9,7 @@
 import wandb

 # First-party
-from neural_lam import constants, metrics, utils, vis
+from neural_lam import config, metrics, utils, vis


 class ARModel(pl.LightningModule):
@@ -24,10 +24,13 @@ class ARModel(pl.LightningModule):

     def __init__(self, args):
         super().__init__()
         self.save_hyperparameters()
-        self.lr = args.lr
+        self.args = args
+        self.config_loader = config.Config.from_file(args.data_config)

         # Load static features for grid/data
-        static_data_dict = utils.load_static_data(args.dataset)
+        static_data_dict = utils.load_static_data(
+            self.config_loader.dataset.name
+        )
         for static_data_name, static_data_tensor in static_data_dict.items():
             self.register_buffer(
                 static_data_name, static_data_tensor, persistent=False
@@ -36,14 +39,11 @@ def __init__(self, args):

         # Double grid output dim. to also output std.-dev.
         self.output_std = bool(args.output_std)
         if self.output_std:
-            self.grid_output_dim = (
-                2 * constants.GRID_STATE_DIM
-            )  # Pred. dim. in grid cell
+            # Pred. dim. in grid cell
+            self.grid_output_dim = 2 * self.config_loader.num_data_vars()
         else:
-            self.grid_output_dim = (
-                constants.GRID_STATE_DIM
-            )  # Pred. dim. in grid cell
-
+            # Pred. dim. in grid cell
+            self.grid_output_dim = self.config_loader.num_data_vars()
         # Store constant per-variable std.-dev. weighting
         # Note that this is the inverse of the multiplicative weighting
         # in wMSE/wMAE
@@ -57,11 +57,11 @@ def __init__(self, args):
         (
             self.num_grid_nodes,
             grid_static_dim,
-        ) = self.grid_static_features.shape  # 63784 = 268x238
+        ) = self.grid_static_features.shape
         self.grid_dim = (
-            2 * constants.GRID_STATE_DIM
+            2 * self.config_loader.num_data_vars()
             + grid_static_dim
-            + constants.GRID_FORCING_DIM
+            + self.config_loader.dataset.num_forcing_features
         )

         # Instantiate loss function
@@ -95,7 +95,7 @@ def __init__(self, args):

     def configure_optimizers(self):
         opt = torch.optim.AdamW(
-            self.parameters(), lr=self.lr, betas=(0.9, 0.95)
+            self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
         )
         if self.opt_state:
             opt.load_state_dict(self.opt_state)
@@ -246,7 +246,7 @@ def validation_step(self, batch, batch_idx):
         # Log loss per time step forward and mean
         val_log_dict = {
             f"val_loss_unroll{step}": time_step_loss[step - 1]
-            for step in constants.VAL_STEP_LOG_ERRORS
+            for step in self.args.val_steps_to_log
         }
         val_log_dict["val_mean_loss"] = mean_loss
         self.log_dict(
@@ -294,7 +294,7 @@ def test_step(self, batch, batch_idx):
         # Log loss per time step forward and mean
         test_log_dict = {
             f"test_loss_unroll{step}": time_step_loss[step - 1]
-            for step in constants.VAL_STEP_LOG_ERRORS
+            for step in self.args.val_steps_to_log
         }
         test_log_dict["test_mean_loss"] = mean_loss

@@ -328,7 +328,9 @@ def test_step(self, batch, batch_idx):
         spatial_loss = self.loss(
             prediction, target, pred_std, average_grid=False
         )  # (B, pred_steps, num_grid_nodes)
-        log_spatial_losses = spatial_loss[:, constants.VAL_STEP_LOG_ERRORS - 1]
+        log_spatial_losses = spatial_loss[
+            :, [step - 1 for step in self.args.val_steps_to_log]
+        ]
         self.spatial_loss_maps.append(log_spatial_losses)
         # (B, N_log, num_grid_nodes)

@@ -399,14 +401,15 @@ def plot_examples(self, batch, n_examples, prediction=None):
                     pred_t[:, var_i],
                     target_t[:, var_i],
                     self.interior_mask[:, 0],
+                    self.config_loader,
                     title=f"{var_name} ({var_unit}), "
-                    f"t={t_i} ({self.step_length*t_i} h)",
+                    f"t={t_i} ({self.step_length * t_i} h)",
                     vrange=var_vrange,
                 )
                 for var_i, (var_name, var_unit, var_vrange) in enumerate(
                     zip(
-                        constants.PARAM_NAMES_SHORT,
-                        constants.PARAM_UNITS,
+                        self.config_loader.dataset.var_names,
+                        self.config_loader.dataset.var_units,
                         var_vranges,
                     )
                 )
@@ -417,7 +420,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
                 {
                     f"{var_name}_example_{example_i}": wandb.Image(fig)
                     for var_name, fig in zip(
-                        constants.PARAM_NAMES_SHORT, var_figs
+                        self.config_loader.dataset.var_names, var_figs
                     )
                 }
             )
@@ -453,7 +456,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
         """
         log_dict = {}
         metric_fig = vis.plot_error_map(
-            metric_tensor, step_length=self.step_length
+            metric_tensor, self.config_loader, step_length=self.step_length
         )
         full_log_name = f"{prefix}_{metric_name}"
         log_dict[full_log_name] = wandb.Image(metric_fig)
@@ -471,14 +474,14 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
         )

         # Check if metrics are watched, log exact values for specific vars
-        if full_log_name in constants.METRICS_WATCH:
-            for var_i, timesteps in constants.VAR_LEADS_METRICS_WATCH.items():
-                var = constants.PARAM_NAMES_SHORT[var_i]
+        if full_log_name in self.args.metrics_watch:
+            for var_i, timesteps in self.args.var_leads_metrics_watch.items():
+                var = self.config_loader.dataset.var_names[var_i]
                 log_dict.update(
                     {
                         f"{full_log_name}_{var}_step_{step}": metric_tensor[
                             step - 1, var_i
-                        ]  # 1-indexed in constants
+                        ]  # 1-indexed in data_config
                         for step in timesteps
                     }
                 )
@@ -542,10 +545,11 @@ def on_test_epoch_end(self):
                 vis.plot_spatial_error(
                     loss_map,
                     self.interior_mask[:, 0],
-                    title=f"Test loss, t={t_i} ({self.step_length*t_i} h)",
+                    self.config_loader,
+                    title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
                 )
                 for t_i, loss_map in zip(
-                    constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss
+                    self.args.val_steps_to_log, mean_spatial_loss
                 )
             ]

@@ -555,14 +559,14 @@ def on_test_epoch_end(self):

             # also make without title and save as pdf
             pdf_loss_map_figs = [
-                vis.plot_spatial_error(loss_map, self.interior_mask[:, 0])
+                vis.plot_spatial_error(
+                    loss_map, self.interior_mask[:, 0], self.config_loader
+                )
                 for loss_map in mean_spatial_loss
             ]
             pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
             os.makedirs(pdf_loss_maps_dir, exist_ok=True)
-            for t_i, fig in zip(
-                constants.VAL_STEP_LOG_ERRORS, pdf_loss_map_figs
-            ):
+            for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs):
                 fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
             # save mean spatial loss as .pt file also
             torch.save(
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 31715502..836b04ed 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -7,9 +7,6 @@
 from torch import nn
 from tueplots import bundles, figsizes

-# First-party
-from neural_lam import constants
-

 def load_dataset_stats(dataset_name, device="cpu"):
     """
@@ -263,11 +260,11 @@ def fractional_plot_bundle(fraction):
     return bundle


-def init_wandb_metrics(wandb_logger):
+def init_wandb_metrics(wandb_logger, val_steps):
     """
     Set up wandb metrics to track
     """
     experiment = wandb_logger.experiment
     experiment.define_metric("val_mean_loss", summary="min")
-    for step in constants.VAL_STEP_LOG_ERRORS:
+    for step in val_steps:
         experiment.define_metric(f"val_loss_unroll{step}", summary="min")
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index cef34a84..2b6abf15 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -4,11 +4,11 @@
 import numpy as np

 # First-party
-from neural_lam import constants, utils
+from neural_lam import utils


 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, title=None, step_length=3):
+def plot_error_map(errors, data_config, title=None, step_length=3):
     """
     Plot a heatmap of errors of different variables at different
     prediction horizons
@@ -51,7 +51,7 @@ def plot_error_map(errors, title=None, step_length=3):
     y_ticklabels = [
         f"{name} ({unit})"
         for name, unit in zip(
-            constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS
+            data_config.dataset.var_names, data_config.dataset.var_units
         )
     ]
     ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
@@ -63,7 +63,9 @@


 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_prediction(pred, target, obs_mask, title=None, vrange=None):
+def plot_prediction(
+    pred, target, obs_mask, data_config, title=None, vrange=None
+):
     """
     Plot example prediction and ground truth.
     Each has shape (N_grid,)
@@ -76,23 +78,25 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None):
         vmin, vmax = vrange

     # Set up masking of border region
-    mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE)
+    mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
     pixel_alpha = (
         mask_reshaped.clamp(0.7, 1).cpu().numpy()
     )  # Faded border region

     fig, axes = plt.subplots(
-        1, 2, figsize=(13, 7), subplot_kw={"projection": constants.LAMBERT_PROJ}
+        1,
+        2,
+        figsize=(13, 7),
+        subplot_kw={"projection": data_config.coords_projection},
     )

     # Plot pred and target
     for ax, data in zip(axes, (target, pred)):
         ax.coastlines()  # Add coastline outlines
-        data_grid = data.reshape(*constants.GRID_SHAPE).cpu().numpy()
+        data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy()
         im = ax.imshow(
             data_grid,
             origin="lower",
-            extent=constants.GRID_LIMITS,
             alpha=pixel_alpha,
             vmin=vmin,
             vmax=vmax,
@@ -112,7 +116,7 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None):


 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_spatial_error(error, obs_mask, title=None, vrange=None):
+def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
     """
     Plot errors over spatial map
     Error and obs_mask has shape (N_grid,)
@@ -125,22 +129,22 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None):
         vmin, vmax = vrange

     # Set up masking of border region
-    mask_reshaped = obs_mask.reshape(*constants.GRID_SHAPE)
+    mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
     pixel_alpha = (
         mask_reshaped.clamp(0.7, 1).cpu().numpy()
     )  # Faded border region

     fig, ax = plt.subplots(
-        figsize=(5, 4.8), subplot_kw={"projection": constants.LAMBERT_PROJ}
+        figsize=(5, 4.8),
+        subplot_kw={"projection": data_config.coords_projection},
     )

     ax.coastlines()  # Add coastline outlines
-    error_grid = error.reshape(*constants.GRID_SHAPE).cpu().numpy()
+    error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy()

     im = ax.imshow(
         error_grid,
         origin="lower",
-        extent=constants.GRID_LIMITS,
         alpha=pixel_alpha,
         vmin=vmin,
         vmax=vmax,
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index eeefc313..a782806b 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -8,7 +8,7 @@
 import torch

 # First-party
-from neural_lam import constants, utils
+from neural_lam import utils


 class WeatherDataset(torch.utils.data.Dataset):
@@ -218,9 +218,11 @@ def __getitem__(self, idx):
         # can roll over to next year, ok because periodicity

         # Encode as sin/cos
+        # ! Make this more flexible in a separate create_forcings.py script
+        seconds_in_year = 365 * 24 * 3600
         hour_angle = (hour_of_day / 12) * torch.pi  # (sample_len,)
         year_angle = (
-            (second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi
+            (second_into_year / seconds_in_year) * 2 * torch.pi
         )  # (sample_len,)
         datetime_forcing = torch.stack(
             (
diff --git a/plot_graph.py b/plot_graph.py
index 48427d5c..40b2b41d 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -7,7 +7,7 @@
 import torch_geometric as pyg

 # First-party
-from neural_lam import utils
+from neural_lam import config, utils

 MESH_HEIGHT = 0.1
 MESH_LEVEL_DIST = 0.2
@@ -20,10 +20,10 @@ def main():
     """
     parser = ArgumentParser(description="Plot graph")
     parser.add_argument(
-        "--dataset",
+        "--data_config",
         type=str,
-        default="meps_example",
-        help="Datast to load grid coordinates from (default: meps_example)",
+        default="neural_lam/data_config.yaml",
+        help="Path to data config file (default: neural_lam/data_config.yaml)",
     )
     parser.add_argument(
         "--graph",
@@ -44,6 +44,7 @@ def main():
     )

     args = parser.parse_args()
+    config_loader = config.Config.from_file(args.data_config)

     # Load graph data
     hierarchical, graph_ldict = utils.load_graph(args.graph)
@@ -62,7 +63,7 @@ def main():
     )
     mesh_static_features = graph_ldict["mesh_static_features"]

-    grid_static_features = utils.load_static_data(args.dataset)[
+    grid_static_features = utils.load_static_data(config_loader.dataset.name)[
         "grid_static_features"
     ]

diff --git a/train_model.py b/train_model.py
index 96d21a3f..390da6d4 100644
--- a/train_model.py
+++ b/train_model.py
@@ -9,7 +9,7 @@
 from lightning_fabric.utilities import seed

 # First-party
-from neural_lam import constants, utils
+from neural_lam import config, utils
 from neural_lam.models.graph_lam import GraphLAM
 from neural_lam.models.hi_lam import HiLAM
 from neural_lam.models.hi_lam_parallel import HiLAMParallel
@@ -29,14 +29,11 @@ def main():
     parser = ArgumentParser(
         description="Train or evaluate NeurWP models for LAM"
     )
-
-    # General options
     parser.add_argument(
-        "--dataset",
+        "--data_config",
         type=str,
-        default="meps_example",
-        help="Dataset, corresponding to name in data directory "
-        "(default: meps_example)",
+        default="neural_lam/data_config.yaml",
+        help="Path to data config file (default: neural_lam/data_config.yaml)",
     )
     parser.add_argument(
         "--model",
@@ -183,8 +180,36 @@ def main():
         help="Number of example predictions to plot during evaluation "
         "(default: 1)",
     )
+
+    # Logger Settings
+    parser.add_argument(
+        "--wandb_project",
+        type=str,
+        default="neural_lam",
+        help="Wandb project name (default: neural_lam)",
+    )
+    parser.add_argument(
+        "--val_steps_to_log",
+        nargs="+",
+        type=int,
+        default=[1, 2, 3, 5, 10, 15, 19],
+        help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])",
+    )
+    parser.add_argument(
+        "--metrics_watch",
+        nargs="+",
+        default=[],
+        help="List of metrics to watch, including any prefix (e.g. val_rmse)",
+    )
+    parser.add_argument(
+        "--var_leads_metrics_watch",
+        type=dict,
+        default={},
+        help="Dict with variables and lead times to log watched metrics for",
+    )
     args = parser.parse_args()

+    config_loader = config.Config.from_file(args.data_config)
+
     # Asserts for arguments
     assert args.model in MODELS, f"Unknown model: {args.model}"
     assert args.step_length <= 3, "Too high step length"
@@ -203,7 +228,7 @@ def main():
     # Load data
     train_loader = torch.utils.data.DataLoader(
         WeatherDataset(
-            args.dataset,
+            config_loader.dataset.name,
             pred_length=args.ar_steps,
             split="train",
             subsample_step=args.step_length,
@@ -217,7 +242,7 @@ def main():
     max_pred_length = (65 // args.step_length) - 2  # 19
     val_loader = torch.utils.data.DataLoader(
         WeatherDataset(
-            args.dataset,
+            config_loader.dataset.name,
             pred_length=max_pred_length,
             split="val",
             subsample_step=args.step_length,
@@ -264,7 +289,7 @@ def main():
         save_last=True,
     )
     logger = pl.loggers.WandbLogger(
-        project=constants.WANDB_PROJECT, name=run_name, config=args
+        project=args.wandb_project, name=run_name, config=args
    )
     trainer = pl.Trainer(
         max_epochs=args.epochs,
@@ -280,7 +305,9 @@ def main():

     # Only init once, on rank 0 only
     if trainer.global_rank == 0:
-        utils.init_wandb_metrics(logger)  # Do after wandb.init
+        utils.init_wandb_metrics(
+            logger, args.val_steps_to_log
+        )  # Do after wandb.init

     if args.eval:
         if args.eval == "val":
@@ -288,7 +315,7 @@ def main():
         else:  # Test
             eval_loader = torch.utils.data.DataLoader(
                 WeatherDataset(
-                    args.dataset,
+                    config_loader.dataset.name,
                     pred_length=max_pred_length,
                     split="test",
                     subsample_step=args.step_length,