Skip to content

Commit

Permalink
first implementation of write-to-zarr during eval
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Jan 24, 2025
1 parent 2f5c32e commit b80d010
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 53 deletions.
129 changes: 76 additions & 53 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import torch
import wandb
import xarray as xr
from loguru import logger

# Local
from .. import metrics, vis
from ..config import NeuralLAMConfig
from ..datastore import BaseDatastore
from ..datastore.base import BaseRegularGridDatastore
from ..loss_weighting import get_state_feature_weighting
from ..weather_dataset import WeatherDataset

Expand Down Expand Up @@ -184,7 +186,7 @@ def _create_dataarray_from_tensor(
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
time = np.array(time.cpu(), dtype="datetime64[ns]")
da = weather_dataset.create_dataarray_from_tensor(
tensor=tensor.cpu().numpy(), time=time, category=category
tensor=tensor, time=time, category=category
)
return da

Expand Down Expand Up @@ -371,15 +373,79 @@ def on_validation_epoch_end(self):
for metric_list in self.val_metrics.values():
metric_list.clear()

def _save_predictions_to_zarr(
self,
batch_times: torch.Tensor,
batch_predictions: torch.Tensor,
batch_idx: int,
zarr_output_path: str,
):
"""
Save state predictions for single batch to zarr dataset. Will append to
existing dataset for batch_idx > 0. Resulting dataset will contain a
variable named `state` with coordinates (start_time,
elapsed_forecast_duration, grid_index, state_feature).
Parameters
----------
batch_times : torch.Tensor[int]
The times for the batch, given as epoch time in nanoseconds. Shape
is (B, args.pred_steps) where B is the batch size and
args.pred_steps is the number of prediction steps.
batch_predictions : torch.Tensor[float]
The predictions for the batch, given as (B, args.pred_steps,
num_grid_nodes, d_f) where B is the batch size, args.pred_steps is
the number of prediction steps, num_grid_nodes is the number of
grid nodes, and d_f is the number of state features.
batch_idx : int
The index of the batch in the current epoch.
"""
batch_size = batch_predictions.shape[0]
# Convert predictions to DataArray using _create_dataarray_from_tensor
das_pred = []
for i in range(len(batch_times)):
da_pred = self._create_dataarray_from_tensor(
tensor=batch_predictions[i],
time=batch_times[i],
split="test",
category="state",
)
# Unstack grid coords if necessary, this also avoids the need to
# try to store a MultiIndex zarr dataset which is not supported by
# xarray
if isinstance(self._datastore, BaseRegularGridDatastore):
da_pred = self._datastore.unstack_grid_coords(da_pred)

t0 = da_pred.coords["time"].values[0]
da_pred.coords["start_time"] = t0
da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0
da_pred = da_pred.swap_dims({"time": "elapsed_forecast_duration"})
da_pred.name = "state"
das_pred.append(da_pred)

da_pred_batch = xr.concat(das_pred, dim="start_time")

# Apply chunking along analysis_time so that each batch is saved as a
# separate chunk
da_pred_batch = da_pred_batch.chunk({"start_time": batch_size})

if batch_idx == 0:
logger.info(f"Saving predictions to {zarr_output_path}")
da_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
else:
da_pred_batch.to_zarr(
zarr_output_path, mode="a", append_dim="start_time"
)

# pylint: disable-next=unused-argument
def test_step(self, batch, batch_idx):
"""
Run test on single batch
"""
# TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)

time_step_loss = torch.mean(
self.loss(
Expand Down Expand Up @@ -435,56 +501,13 @@ def test_step(self, batch, batch_idx):
self.spatial_loss_maps.append(log_spatial_losses)
# (B, N_log, num_grid_nodes)

# Convert predictions to DataArray using _create_dataarray_from_tensor
prediction_da = self._create_dataarray_from_tensor(prediction, batch_times, "predictions")

# Extract dimensions and coordinates from prediction_da for the Dataset
elapsed_forecast_duration = np.arange(prediction.shape[1]) # TODO:Forecast steps
grid_index = np.arange(prediction.shape[2]) # TODO:Spatial grid points
state_features = [f"feature_{i}" for i in range(prediction.shape[-1])] # TODO: State features

# Create Dataset with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature]
ds_prediction = xr.Dataset(
{
"state": (
["analysis_time", "elapsed_forecast_duration", "grid_index", "state_feature"],
prediction_da.values, # Use values from the DataArray
),
},
coords={
"analysis_time": prediction_da.coords["analysis_time"].values,
"elapsed_forecast_duration": elapsed_forecast_duration,
"grid_index": grid_index,
"state_feature": state_features,
},
attrs={
"description": "Predictions from ARModel",
"model": self.hparams.model_name,
},
)

# Apply chunking along analysis_time
ds_prediction = ds_prediction.chunk({"analysis_time": 1})

# Save predictions to Zarr
forecast_save_path = "path/to/my/directory"
if forecast_save_path:
#if self.args.forecast_save_path:

#zarr_output_path = self.args.forecast_save_path
zarr_output_path = forecast_save_path
# Ensure the output directory exists
os.makedirs(os.path.dirname(zarr_output_path), exist_ok=True)

# Save or append to Zarr using region
if batch_idx == 0:
ds_prediction.to_zarr(zarr_output_path, mode="w", consolidated=True)
else:
ds_prediction.to_zarr(
zarr_output_path,
mode="a",
region={"analysis_time": slice(batch_idx, batch_idx + 1)},
)
if self.args.save_eval_to_zarr_path:
self._save_predictions_to_zarr(
batch_times=batch_times,
batch_predictions=prediction,
batch_idx=batch_idx,
zarr_output_path=self.args.save_eval_to_zarr_path,
)

# Plot example predictions (on rank 0 only)
if (
Expand Down
5 changes: 5 additions & 0 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ def main(input_args=None):
help="Eval model on given data split (val/test) "
"(default: None (train model))",
)
parser.add_argument(
"--save-eval-to-zarr-path",
type=str,
help="Save evaluation results to zarr dataset at given path ",
)
parser.add_argument(
"--ar_steps_eval",
type=int,
Expand Down

0 comments on commit b80d010

Please sign in to comment.