From b80d010a5a2219a0636ec66f7ffb56753ab33ec7 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Fri, 24 Jan 2025 17:35:17 +0100
Subject: [PATCH] first implementation of write-to-zarr during eval

---
 neural_lam/models/ar_model.py | 129 ++++++++++++++++++++--------------
 neural_lam/train_model.py     |   5 ++
 2 files changed, 81 insertions(+), 53 deletions(-)

diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index ab7f5b86..dc8cfa05 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,11 +9,13 @@
 import torch
 import wandb
 import xarray as xr
+from loguru import logger
 
 # Local
 from .. import metrics, vis
 from ..config import NeuralLAMConfig
 from ..datastore import BaseDatastore
+from ..datastore.base import BaseRegularGridDatastore
 from ..loss_weighting import get_state_feature_weighting
 from ..weather_dataset import WeatherDataset
 
@@ -184,7 +186,7 @@ def _create_dataarray_from_tensor(
         weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
         time = np.array(time.cpu(), dtype="datetime64[ns]")
         da = weather_dataset.create_dataarray_from_tensor(
-            tensor=tensor.cpu().numpy(), time=time, category=category
+            tensor=tensor, time=time, category=category
         )
         return da
 
@@ -371,6 +373,70 @@ def on_validation_epoch_end(self):
         for metric_list in self.val_metrics.values():
             metric_list.clear()
 
+    def _save_predictions_to_zarr(
+        self,
+        batch_times: torch.Tensor,
+        batch_predictions: torch.Tensor,
+        batch_idx: int,
+        zarr_output_path: str,
+    ):
+        """
+        Save state predictions for single batch to zarr dataset. Will append to
+        existing dataset for batch_idx > 0. Resulting dataset will contain a
+        variable named `state` with coordinates (start_time,
+        elapsed_forecast_duration, grid_index, state_feature).
+
+        Parameters
+        ----------
+        batch_times : torch.Tensor[int]
+            The times for the batch, given as epoch time in nanoseconds. Shape
+            is (B, args.pred_steps) where B is the batch size and
+            args.pred_steps is the number of prediction steps.
+        batch_predictions : torch.Tensor[float]
+            The predictions for the batch, given as (B, args.pred_steps,
+            num_grid_nodes, d_f) where B is the batch size, args.pred_steps is
+            the number of prediction steps, num_grid_nodes is the number of
+            grid nodes, and d_f is the number of state features.
+        batch_idx : int
+            The index of the batch in the current epoch.
+        """
+        batch_size = batch_predictions.shape[0]
+        # Convert predictions to DataArray using _create_dataarray_from_tensor
+        das_pred = []
+        for i in range(len(batch_times)):
+            da_pred = self._create_dataarray_from_tensor(
+                tensor=batch_predictions[i],
+                time=batch_times[i],
+                split="test",
+                category="state",
+            )
+            # Unstack grid coords if necessary, this also avoids the need to
+            # try to store a MultiIndex zarr dataset which is not supported by
+            # xarray
+            if isinstance(self._datastore, BaseRegularGridDatastore):
+                da_pred = self._datastore.unstack_grid_coords(da_pred)
+
+            t0 = da_pred.coords["time"].values[0]
+            da_pred.coords["start_time"] = t0
+            da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0
+            da_pred = da_pred.swap_dims({"time": "elapsed_forecast_duration"})
+            da_pred.name = "state"
+            das_pred.append(da_pred)
+
+        da_pred_batch = xr.concat(das_pred, dim="start_time")
+
+        # Apply chunking along start_time so that each batch is saved as a
+        # separate chunk
+        da_pred_batch = da_pred_batch.chunk({"start_time": batch_size})
+
+        if batch_idx == 0:
+            logger.info(f"Saving predictions to {zarr_output_path}")
+            da_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
+        else:
+            da_pred_batch.to_zarr(
+                zarr_output_path, mode="a", append_dim="start_time"
+            )
+
     # pylint: disable-next=unused-argument
     def test_step(self, batch, batch_idx):
         """
@@ -378,8 +444,8 @@ def test_step(self, batch, batch_idx):
         """
         # TODO Here batch_times can be used for plotting routines
         prediction, target, pred_std, batch_times = self.common_step(batch)
-        # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
-        # pred_steps, num_grid_nodes, d_f) or (d_f,)
+        # prediction: (B, pred_steps, num_grid_nodes, d_f)
+        # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
 
         time_step_loss = torch.mean(
             self.loss(
@@ -435,56 +501,13 @@ def test_step(self, batch, batch_idx):
         self.spatial_loss_maps.append(log_spatial_losses)
         # (B, N_log, num_grid_nodes)
 
-        # Convert predictions to DataArray using _create_dataarray_from_tensor
-        prediction_da = self._create_dataarray_from_tensor(prediction, batch_times, "predictions")
-
-        # Extract dimensions and coordinates from prediction_da for the Dataset
-        elapsed_forecast_duration = np.arange(prediction.shape[1])  # TODO:Forecast steps
-        grid_index = np.arange(prediction.shape[2])  # TODO:Spatial grid points
-        state_features = [f"feature_{i}" for i in range(prediction.shape[-1])]  # TODO: State features
-
-        # Create Dataset with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature]
-        ds_prediction = xr.Dataset(
-            {
-                "state": (
-                    ["analysis_time", "elapsed_forecast_duration", "grid_index", "state_feature"],
-                    prediction_da.values,  # Use values from the DataArray
-                ),
-            },
-            coords={
-                "analysis_time": prediction_da.coords["analysis_time"].values,
-                "elapsed_forecast_duration": elapsed_forecast_duration,
-                "grid_index": grid_index,
-                "state_feature": state_features,
-            },
-            attrs={
-                "description": "Predictions from ARModel",
-                "model": self.hparams.model_name,
-            },
-        )
-
-        # Apply chunking along analysis_time
-        ds_prediction = ds_prediction.chunk({"analysis_time": 1})
-
-        # Save predictions to Zarr
-        forecast_save_path = "path/to/my/directory"
-        if forecast_save_path:
-        #if self.args.forecast_save_path:
-
-            #zarr_output_path = self.args.forecast_save_path
-            zarr_output_path = forecast_save_path
-            # Ensure the output directory exists
-            os.makedirs(os.path.dirname(zarr_output_path), exist_ok=True)
-
-            # Save or append to Zarr using region
-            if batch_idx == 0:
-                ds_prediction.to_zarr(zarr_output_path, mode="w", consolidated=True)
-            else:
-                ds_prediction.to_zarr(
-                    zarr_output_path,
-                    mode="a",
-                    region={"analysis_time": slice(batch_idx, batch_idx + 1)},
-                )
+        if self.args.save_eval_to_zarr_path:
+            self._save_predictions_to_zarr(
+                batch_times=batch_times,
+                batch_predictions=prediction,
+                batch_idx=batch_idx,
+                zarr_output_path=self.args.save_eval_to_zarr_path,
+            )
 
         # Plot example predictions (on rank 0 only)
         if (
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 74146c89..acbf4bca 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -149,6 +149,11 @@ def main(input_args=None):
         help="Eval model on given data split (val/test) "
         "(default: None (train model))",
     )
+    parser.add_argument(
+        "--save-eval-to-zarr-path",
+        type=str,
+        help="Save evaluation results to a zarr dataset at the given path",
+    )
     parser.add_argument(
         "--ar_steps_eval",
         type=int,
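
A minimal usage sketch (not part of the patch itself), assuming the project's
usual `python -m neural_lam.train_model` entry point and the flag added above;
other evaluation arguments are omitted and `predictions.zarr` is a placeholder
output path. Dimension names follow the docstring of
`_save_predictions_to_zarr` above.

    # Hypothetical invocation (other eval arguments omitted):
    #   python -m neural_lam.train_model <eval arguments> \
    #       --save-eval-to-zarr-path predictions.zarr
    #
    # Inspecting the written predictions with xarray, which this patch
    # already uses:
    import xarray as xr

    ds = xr.open_zarr("predictions.zarr")  # placeholder path
    da_state = ds["state"]
    # Expect dims along the lines of (start_time, elapsed_forecast_duration,
    # grid coordinates, state_feature), with one chunk per evaluation batch
    # along start_time.
    print(da_state)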