
Add writing to zarr dataset for eval-mode of trained models #104

Open
wants to merge 4 commits into base: main

Conversation

leifdenby
Member

Describe your changes

Adds a new CLI flag to neural_lam.train_model called --save-eval-to-zarr-path <path-to-dataset>, which can be used when running neural-lam in eval mode (i.e. neural_lam.train_model --eval ...) to write the predictions to a zarr dataset stored at <path-to-dataset>. This functionality is motivated by wanting to store model predictions for later verification.

Example of usage:

Model trained with

$> pdm run python -m neural_lam.train_model --config_path tests/datastore_examples/mdp/danra_100m_winds/config.yaml --hidden_dim 2 --epochs 1 --ar_steps_train 3 --ar_steps_eval 3 --graph 1level

used for inference with

$> pdm run python -m neural_lam.train_model \
   --config_path tests/datastore_examples/mdp/danra_100m_winds/config.yaml \
   --hidden_dim 2 --epochs 1 --ar_steps_train 1 --ar_steps_eval 3 --eval val \
   --load saved_models/train-graph_lam-4x2-01_24_17-2502/min_val_loss.ckpt \
   --val_steps_to_log 3 --graph 1level --save-eval-to-zarr-path state_predictions.zarr/

results in:

$> zarrdump state_predictions.zarr
<xarray.Dataset> Size: 123MB
Dimensions:                    (elapsed_forecast_duration: 3, start_time: 11,
                                state_feature: 2, x: 789, y: 589)
Coordinates:
  * elapsed_forecast_duration  (elapsed_forecast_duration) timedelta64[ns] 24B ...
  * start_time                 (start_time) datetime64[ns] 88B 1990-09-07T06:...
  * state_feature              (state_feature) <U5 40B 'u100m' 'v100m'
    time                       (start_time, elapsed_forecast_duration) datetime64[ns] 264B dask.array<chunksize=(4, 3), meta=np.ndarray>
  * x                          (x) float64 6kB -1.999e+06 ... -2.925e+04
  * y                          (y) float64 5kB -6.095e+05 ... 8.605e+05
Data variables:
    state                      (start_time, elapsed_forecast_duration, state_feature, x, y) float32 123MB dask.array<chunksize=(4, 3, 2, 789, 589), meta=np.ndarray>

NB: This does not implement the inversion of the transformations that take place in mllam-data-prep (e.g. splitting individual features back into separate variables and levels). Also, the zarr datasets store time as [start_time, elapsed_forecast_duration] rather than [sample, time] (i.e. absolute time per sample), to avoid producing a large array with many empty values (NaNs), which would otherwise happen because each sample has a different start time. The snippet below demonstrates how one could return to absolute time (there is probably a better way to do this...):

import xarray as xr
import matplotlib.pyplot as plt

ds = xr.open_zarr("state_predictions.zarr/", chunks={})

# Plot against the relative time axis (elapsed forecast duration)
ds.state.isel(x=0, y=0, start_time=slice(0, 4)).plot(hue="start_time", col="state_feature")
plt.savefig("state_predictions_relative_time.png")

# Return to absolute time: for each forecast, swap the elapsed_forecast_duration
# dimension for the absolute time coordinate, then concatenate the forecasts
# along a new "sample" dimension
ds_abs_time = xr.concat(
    [
        ds.isel(start_time=i).swap_dims(dict(elapsed_forecast_duration="time"))
        for i in range(len(ds.start_time))
    ],
    dim="sample",
)
ds_abs_time.state.isel(x=0, y=0).plot(hue="sample", col="state_feature")
plt.savefig("state_predictions_absolute_time.png")

Example plot with time-axis showing elapsed time: [image: state_predictions_relative_time]

Example plot with time-axis showing absolute time: [image: state_predictions_absolute_time]

This probably needs more work, but I think it is ready for people to try it out and let me know what they think 😄

Issue Link

Implements #89

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • author has added an entry to the changelog (and designated the change as added, changed or fixed)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@sadamov sadamov linked an issue Jan 25, 2025 that may be closed by this pull request
@leifdenby leifdenby added the enhancement New feature or request label Feb 6, 2025
@SimonKamuk
Contributor

I haven't tested the code myself yet, but this looks great! One question: is test_step called on all GPUs in a multi-GPU setup? If so, I guess multiple GPUs writing to the same zarr is not an issue as long as the chunks are different, but the metadata is only written correctly if the GPU that handles batch_idx 0 finishes first, right?

@leifdenby
Member Author

leifdenby commented Feb 10, 2025

One question: is test_step called on all GPUs in a multi-GPU setup? If so, I guess multiple GPUs writing to the same zarr is not an issue as long as the chunks are different, but the metadata is only written correctly if the GPU that handles batch_idx 0 finishes first, right?

That is a good point. I haven't actually tested writing when inference is run across multiple GPUs in parallel. There could be a race condition due to the fact that I create the zarr dataset (i.e. write the meta info with ds.to_zarr(path, mode="w")) only for batch_idx == 0, and for other batch_idx I append with ds.to_zarr(path, mode="a") https://github.com/mllam/neural-lam/pull/104/files#diff-043440a2d7a2cd62bb349e74c9fe4f55e69a8d0e4801f6ab2e150d381b11c74dR429. So if a batch_idx != 0 is ready before batch_idx == 0, then I think this could cause an exception to be raised (since there will be no dataset to append to). I am actually not sure how to do this correctly. @observingClouds do you know? I started by looking into the region argument of xr.Dataset.to_zarr(..., region=...), but for that I think you also have to write the zarr metadata first.

Could we get away with issuing a warning to say this has only been tested for single GPU inference so far?
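
For reference, a condensed, illustrative sketch of the write pattern described above; the append_dim shown here is an assumption, as the PR only shows mode="a":

# The store is created on the first batch and appended to afterwards.
# If a rank handling batch_idx != 0 reaches the append branch before the
# store exists, to_zarr fails since there is nothing to append to.
if batch_idx == 0:
    # creates the store and writes the zarr metadata
    da_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
else:
    # appends along the start_time dimension (append_dim assumed here)
    da_pred_batch.to_zarr(zarr_output_path, mode="a", append_dim="start_time")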

@leifdenby leifdenby added this to the v0.5.0 milestone Feb 10, 2025
@observingClouds
Contributor

@leifdenby do we know all available output times at the time of starting the write process? If so we can create the metadata first and chunks can be written independently without conflicting.
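
For illustration, a minimal sketch of that metadata-first approach; the shapes, coordinate values, and paths below are placeholders, not taken from the PR:

import dask.array
import pandas as pd
import xarray as xr

# Placeholder coordinates; in neural-lam these would come from the datastore
start_times = pd.date_range("1990-09-07T06:00", periods=11, freq="3h")
lead_times = pd.timedelta_range(start="3h", periods=3, freq="3h")
features = ["u100m", "v100m"]
nx, ny = 789, 589

# Lazy template covering the full output shape; nothing is computed here
template = xr.Dataset(
    {
        "state": (
            ("start_time", "elapsed_forecast_duration", "state_feature", "x", "y"),
            dask.array.zeros(
                (len(start_times), len(lead_times), len(features), nx, ny),
                chunks=(1, -1, -1, -1, -1),
                dtype="float32",
            ),
        )
    },
    coords={
        "start_time": start_times,
        "elapsed_forecast_duration": lead_times,
        "state_feature": features,
    },
)

# Write only the zarr metadata and coordinates up front
template.to_zarr("state_predictions.zarr", mode="w", compute=False, consolidated=True)

# Each process can then fill in its own slice of start times independently, e.g.
# ds_batch.to_zarr("state_predictions.zarr", region={"start_time": slice(i, i + n)})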

@leifdenby
Member Author

do we know all available output times at the time of starting the write process

Yes, we should be able to get these from the datastore. So you would go with using the region argument then? Also, does that mean that we'd create the metadata on, say, the rank 0 process first and then each separate process would write its own regions? That would mean introducing something like an MPI barrier after the metadata write, right @observingClouds?

@joeloskarsson
Collaborator

does that mean that we'd create the metadata on say the rank 0 process first and then each separate process would write its own ranges

I think this sounds like a good idea. You could use something like https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier for that.

I am interested in having this work with multi-GPU, so I don't think we should just issue a warning and ignore the multi-GPU case. Something that might complicate things (if we allow batch_size > 1) is the note from https://github.com/joeloskarsson/neural-lam-dev?tab=readme-ov-file#evaluate-models.

Note: While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the DistributedSampler will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. A possible workaround is to just use batch size 1 during evaluation. This issue stems from PyTorch Lightning. See for example Lightning-AI/torchmetrics#1886 for more discussion.

If samples are duplicated you could end up with different processes writing to the same region. So that is something to think about.
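
For reference, a minimal sketch of the rank-0-writes-metadata-then-barrier idea, assuming torch.distributed has already been initialised (e.g. by Lightning); the function name and the template_ds argument are illustrative:

import torch.distributed as dist

def setup_zarr_output(zarr_output_path, template_ds):
    """Create the zarr store metadata once on rank 0, then sync all ranks
    before any rank starts writing its own regions."""
    is_distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_distributed else 0

    if rank == 0:
        # Write metadata (and coordinates) only; data is filled in later per region
        template_ds.to_zarr(
            zarr_output_path, mode="w", compute=False, consolidated=True
        )

    if is_distributed:
        # Ensure the store exists before non-zero ranks begin their region writes
        dist.barrier()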

@leifdenby
Member Author

I am interested in having this work with multi-GPU, so I don't think we should just issue a warning and ignore the multi-GPU case. Something that might complicate things (if we allow batch_size > 1) is the note from https://github.com/joeloskarsson/neural-lam-dev?tab=readme-ov-file#evaluate-models.

Ok, things are never as easy as they seem 😆

What I am proposing is that we merge this single-GPU implementation (so that people can start using it) and add multi-GPU support in a later PR once we have figured out how to do that. If we did that, I would have the single-GPU implementation issue a warning.

@sadamov
Collaborator

sadamov commented Feb 13, 2025

I have implemented a version that works with multi-GPU, using region="auto" after writing the full metadata initially. You can find the code here: https://github.com/sadamov/neural-lam/tree/write_zarr

And here is the output of a --eval test on the danra test datastore: [image]

Shortcomings:

  • The chunking breaks down for large datastores and needs some work (manual chunking?) - could be related to the reindexing step
  • The logging is too verbose and not helpful

Thought this might be useful here, or in the follow-up PR.
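
For context, the per-batch write under that scheme could look roughly like this (a sketch; ds_batch and zarr_output_path are illustrative, and region="auto" requires the batch's start_time values to already be present in the store's coordinate):

# The full metadata (including the complete start_time coordinate) has already
# been written once; region="auto" lets xarray infer which index ranges to
# overwrite from the coordinate values of this batch.
ds_batch.to_zarr(zarr_output_path, region="auto")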

Comment on lines +416 to +418
t0 = da_pred.coords["time"].values[0]
da_pred.coords["start_time"] = t0
da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0
Collaborator

Please have another look over how target times are used here. The batch_times being fed in here are not [analysis_time, analysis_time + time_step, analysis_time + 2*time_step, ...], but rather [analysis_time + time_step, analysis_time + 2*time_step, ...]. This is also because batch_predictions does not include the state at the analysis time.

This means that start_time is currently not when the forecast was started.

Collaborator

Just adjusting t0 would fix this: joeloskarsson@721ac5e. However, you probably need to get the step length from somewhere other than in that commit, as self.step_length is not present on main.
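
For example, something along these lines (a sketch; deriving the step length from consecutive target times is just one option, and where it should actually come from is the open question noted above):

# batch_times start at analysis_time + time_step, so step back one step to
# recover the actual forecast start time
time_step = da_pred.coords["time"].values[1] - da_pred.coords["time"].values[0]
t0 = da_pred.coords["time"].values[0] - time_step
da_pred.coords["start_time"] = t0
da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0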

# Convert predictions to DataArray using _create_dataarray_from_tensor
das_pred = []
for i in range(len(batch_times)):
    da_pred = self._create_dataarray_from_tensor(
Collaborator

Now that _create_dataarray_from_tensor is also used here, it seems unreasonable to not properly deal with the hack where a WeatherDataset is instantiated at each call (

# TODO: creating an instance of WeatherDataset here on every call is
# not how this should be done but whether WeatherDataset should be
# provided to ARModel or where to put plotting still needs discussion
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
). As is, we are instantiating O(NT) WeatherDatasets when saving to zarr (not a memory problem, as we throw them away, but very wasteful). This will quickly become a problem, as the WeatherDatasets will grow when we merge in more boundary-related changes.

Collaborator

This is an alternative hack to avoid constantly making new datasets: joeloskarsson@d277b08. However, this is still very much a hack and the TODO remains that we should handle this in some proper way.
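
One shape such a cache could take (a sketch only; the method name and cache attribute are not from the PR or the linked commit):

def _get_weather_dataset(self, split):
    """Lazily create and reuse one WeatherDataset per split instead of
    instantiating a new one on every call (the TODO above still applies)."""
    if not hasattr(self, "_weather_dataset_cache"):
        self._weather_dataset_cache = {}
    if split not in self._weather_dataset_cache:
        self._weather_dataset_cache[split] = WeatherDataset(
            datastore=self._datastore, split=split
        )
    return self._weather_dataset_cache[split]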


if batch_idx == 0:
    logger.info(f"Saving predictions to {zarr_output_path}")
    da_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
Collaborator

I am getting erroneous start_times with this, due to problems with their encoding. This seems to fix it for me:

Suggested change
da_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
da_pred_batch.to_zarr(
    zarr_output_path,
    mode="w",
    consolidated=True,
    encoding={
        "start_time": {
            "units": "Seconds since 1970-01-01 00:00:00",
            "dtype": "int64",
        },
    },
)

This sets it to use Unix standard time, but I don't know if that is the best solution.

Contributor

@joeloskarsson which zarr version are you using?

Collaborator

This was with zarr 2.18.3; I have not tested with zarr 3.

t0 = da_pred.coords["time"].values[0]
da_pred.coords["start_time"] = t0
da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0
da_pred = da_pred.swap_dims({"time": "elapsed_forecast_duration"})
Collaborator

Do we really want to leave the time coordinate in this DataArray? It looks to me like those values will only be valid for one of the forecasts (and I am not entirely sure which one).
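
If the conclusion is that the coordinate should not be kept, dropping it before writing would be straightforward (a sketch):

# Drop the per-forecast absolute time values, keeping only start_time and
# elapsed_forecast_duration
da_pred = da_pred.drop_vars("time")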

if self.args.save_eval_to_zarr_path:
    self._save_predictions_to_zarr(
        batch_times=batch_times,
        batch_predictions=prediction,
Collaborator

These predictions are in the standardized scale. At some point before these are written to disk in the zarr they should be rescaled to the original data scale.

Collaborator

Can be done as in joeloskarsson@d3f636e.
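
In essence, the rescaling inverts the standardization before writing (a sketch; the names of the per-feature statistics are assumptions, not necessarily what the linked commit uses):

# Predictions are in standardized units; map back to the original data scale
# via x = x_standardized * std + mean before writing to zarr
prediction_rescaled = prediction * self.state_std + self.state_mean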

Labels
enhancement New feature or request
Development

Successfully merging this pull request may close these issues.

Output predictions as zarr dataset
6 participants