diff --git a/.cirun.yml b/.cirun.yml new file mode 100644 index 00000000..21b03ab4 --- /dev/null +++ b/.cirun.yml @@ -0,0 +1,16 @@ +# setup for using github runners via https://cirun.io/ +runners: + - name: "aws-runner" + # Cloud Provider: AWS + cloud: "aws" + # https://aws.amazon.com/ec2/instance-types/g4/ + instance_type: "g4ad.xlarge" + # Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 22.04), Frankfurt region + machine_image: "ami-0ba41b554b28d24a4" + # use Frankfurt region + region: "eu-central-1" + preemptible: false + # Add this label in the "runs-on" param in .github/workflows/.yml + # So that this runner is created for running the workflow + labels: + - "cirun-aws-runner" diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..b02dd545 --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 88 +ignore = E203, F811, I002, W503 diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..9d4aeb54 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,52 @@ +## Describe your changes + +< Summary of the changes.> + +< Please also include relevant motivation and context. > + +< List any dependencies that are required for this change. > + +## Issue Link + +< Link to the relevant issue or task. > (e.g. `closes #00` or `solves #00`) + +## Type of change + +- [ ] 🐛 Bug fix (non-breaking change that fixes an issue) +- [ ] ✨ New feature (non-breaking change that adds functionality) +- [ ] 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] 📖 Documentation (Addition or improvements to documentation) + +## Checklist before requesting a review + +- [ ] My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use `pull` with `--rebase` option if possible). +- [ ] I have performed a self-review of my code +- [ ] For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values +- [ ] I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code +- [ ] I have updated the [README](README.MD) to cover introduced code changes +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] I have given the PR a name that clearly describes the change, written in imperative form ([context](https://www.gitkraken.com/learn/git/best-practices/git-commit-message#using-imperative-verb-form)). +- [ ] I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee. + +## Checklist for reviewers + +Each PR comes with its own improvements and flaws. The reviewer should check the following: +- [ ] the code is readable +- [ ] the code is well tested +- [ ] the code is documented (including return types and parameters) +- [ ] the code is easy to maintain + +## Author checklist after completed review + +- [ ] I have added a line to the CHANGELOG describing this change, in a section + reflecting type of change (add section where missing): + - *added*: when you have added new functionality + - *changed*: when default behaviour of the code has been changed + - *fixes*: when your contribution fixes a bug + +## Checklist for assignee + +- [ ] PR is up to date with the base branch +- [ ] the tests pass +- [ ] author has added an entry to the changelog (and designated the change as *added*, *changed* or *fixed*) +- Once the PR is ready to be merged, squash commits and merge the PR. diff --git a/.github/workflows/ci-pdm-install-and-test-cpu.yml b/.github/workflows/ci-pdm-install-and-test-cpu.yml new file mode 100644 index 00000000..8fb4df79 --- /dev/null +++ b/.github/workflows/ci-pdm-install-and-test-cpu.yml @@ -0,0 +1,55 @@ +# cicd workflow for running tests with pytest +# needs to first install pdm, then install torch cpu manually and then install the package +# then run the tests + +name: test (pdm install, cpu) + +on: [push, pull_request] + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Install pdm + run: | + python -m pip install pdm + + - name: Create venv + run: | + pdm venv create --with-pip + pdm use --venv in-project + + - name: Install torch (CPU) + run: | + pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + # check that the CPU version is installed + + - name: Install package (including dev dependencies) + run: | + pdm install --group :all + + - name: Print and check torch version + run: | + pdm run python -c "import torch; print(torch.__version__)" + pdm run python -c "import torch; assert torch.__version__.endswith('+cpu')" + + - name: Load cache data + uses: actions/cache/restore@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + restore-keys: | + ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + + - name: Run tests + run: | + pdm run pytest -vv -s + + - name: Save cache data + uses: actions/cache/save@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml new file mode 100644 index 00000000..54ab438b --- /dev/null +++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml @@ -0,0 +1,60 @@ +# cicd workflow for running tests with pytest +# needs to first install pdm, then install torch cpu manually and then install the package +# then run the tests + +name: test (pdm install, gpu) + +on: [push, pull_request] + +jobs: + tests: + runs-on: "cirun-aws-runner--${{ github.run_id }}" + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: 3.10 + + - name: Install pdm + run: | + python -m pip install pdm + + - name: Create venv + run: | + pdm config venv.in_project False + pdm config venv.location /opt/dlami/nvme/venv + pdm venv create --with-pip + + - name: Install torch (GPU CUDA 12.1) + run: | + pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cu121 + + - name: Print and check torch version + run: | + pdm run python -c "import torch; print(torch.__version__)" + pdm run python -c "import torch; assert not torch.__version__.endswith('+cpu')" + + - name: Install package (including dev dependencies) + run: | + pdm install --group :all + + - name: Load cache data + uses: actions/cache/restore@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + restore-keys: | + ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + + - name: Run tests + run: | + pdm run pytest -vv -s + + - name: Save cache data + uses: actions/cache/save@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 diff --git a/.github/workflows/ci-pip-install-and-test-cpu.yml b/.github/workflows/ci-pip-install-and-test-cpu.yml new file mode 100644 index 00000000..b131596d --- /dev/null +++ b/.github/workflows/ci-pip-install-and-test-cpu.yml @@ -0,0 +1,45 @@ +# cicd workflow for running tests with pytest +# needs to first install pdm, then install torch cpu manually and then install the package +# then run the tests + +name: test (pip install, cpu) + +on: [push, pull_request] + +jobs: + tests: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Install torch (CPU) + run: | + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + + - name: Install package (including dev dependencies) + run: | + python -m pip install ".[dev]" + + - name: Print and check torch version + run: | + python -c "import torch; print(torch.__version__)" + python -c "import torch; assert torch.__version__.endswith('+cpu')" + + - name: Load cache data + uses: actions/cache/restore@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + restore-keys: | + ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + + - name: Run tests + run: | + python -m pytest -vv -s + + - name: Save cache data + uses: actions/cache/save@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml new file mode 100644 index 00000000..efda1857 --- /dev/null +++ b/.github/workflows/ci-pip-install-and-test-gpu.yml @@ -0,0 +1,50 @@ +# cicd workflow for running tests with pytest +# needs to first install pdm, then install torch cpu manually and then install the package +# then run the tests + +name: test (pip install, gpu) + +on: [push, pull_request] + +jobs: + tests: + runs-on: "cirun-aws-runner--${{ github.run_id }}" + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: 3.10 + + - name: Install torch (GPU CUDA 12.1) + run: | + python -m pip install torch --index-url https://download.pytorch.org/whl/cu121 + + - name: Install package (including dev dependencies) + run: | + python -m pip install ".[dev]" + + - name: Print and check torch version + run: | + python -c "import torch; print(torch.__version__)" + python -c "import torch; assert not torch.__version__.endswith('+cpu')" + + - name: Load cache data + uses: actions/cache/restore@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + restore-keys: | + ${{ runner.os }}-meps-reduced-example-data-v0.2.0 + + - name: Run tests + run: | + python -m pytest -vv -s + + - name: Save cache data + uses: actions/cache/save@v4 + with: + path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip + key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0 diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index a6ad84f1..4e12c314 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,33 +1,23 @@ -name: Run pre-commit job +name: Linting on: - push: + # trigger on pushes to any branch + push: + # and also on PRs to main + pull_request: branches: - - main - pull_request: - branches: - - main + - main jobs: - pre-commit-job: + pre-commit-job: runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} + strategy: + matrix: + python-version: ["3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.9 - - name: Install pre-commit hooks - run: | - pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \ - --index-url https://download.pytorch.org/whl/cpu - pip install -r requirements.txt - pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \ - torch-cluster==1.6.1 torch-geometric==2.3.1 \ - -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html - - name: Run pre-commit hooks - run: | - pre-commit run --all-files + python-version: ${{ matrix.python-version }} + - uses: pre-commit/action@v2.0.3 diff --git a/.gitignore b/.gitignore index 00899e1e..fdb51d3d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,15 @@ ### Project Specific ### wandb -slurm_log* saved_models +lightning_logs data graphs *.sif sweeps -test_*.sh -cosmo_hilam.html -.gitignore .vscode +*.html *.zarr -*.png -shps +*slurm* ### Python ### # Byte-compiled / optimized / DLL files @@ -79,4 +76,14 @@ tags # Coc configuration directory .vim .vscode -boundary_mask_donut.png + +# macos +.DS_Store +__MACOSX + +# pdm (https://pdm-project.org/en/stable/) +.pdm-python +.venv + +# exclude pdm.lock file so that both cpu and gpu versions of torch will be accepted by pdm +pdm.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f48eca67..dfbf8b60 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,51 +1,38 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - - id: check-ast - - id: check-case-conflict - - id: check-docstring-first - - id: check-symlinks - - id: check-toml - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: local + - id: check-ast + - id: check-case-conflict + - id: check-docstring-first + - id: check-symlinks + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 hooks: - - id: codespell - name: codespell + - id: codespell description: Check for spelling errors - language: system - entry: codespell -- repo: local + + - repo: https://github.com/psf/black + rev: 22.3.0 hooks: - - id: black - name: black + - id: black description: Format Python code - language: system - entry: black - types_or: [python, pyi] -- repo: local + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 hooks: - - id: isort - name: isort + - id: isort description: Group and sort Python imports - language: system - entry: isort - types_or: [python, pyi, cython] -- repo: local + + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 hooks: - - id: flake8 - name: flake8 + - id: flake8 description: Check Python code for correctness, consistency and adherence to best practices - language: system - entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503 - types: [python] -- repo: local - hooks: - - id: pylint - name: pylint - entry: pylint -rn -sn - language: system - types: [python] + additional_dependencies: [Flake8-pyproject] diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..42d81149 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,142 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.2.0...HEAD) + +### Added + +- Introduce Datastores to represent input data from different sources, including zarr and numpy. + [\#66](https://github.com/mllam/neural-lam/pull/66) + @leifdenby @sadamov + +- Implement standardization of static features when loaded in ARModel [\#96](https://github.com/mllam/neural-lam/pull/96) @joeloskarsson + +### Fixed + +- Fix wandb environment variable disabling wandb during tests. Now correctly uses WANDB_MODE=disabled. [\#94](https://github.com/mllam/neural-lam/pull/94) @joeloskarsson + +- Fix bugs introduced with datastores functionality relating visualation plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby + +## [v0.2.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.2.0) + +### Added +- Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub, including push to main branch. Added caching of test data to speed up running tests. + [\#38](https://github.com/mllam/neural-lam/pull/38) [\#55](https://github.com/mllam/neural-lam/pull/55) + @SimonKamuk + +- Replaced `constants.py` with `data_config.yaml` for data configuration management + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- new metrics (`nll` and `crps_gauss`) and `metrics` submodule, stddiv output option + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- ability to "watch" metrics and log + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- pre-commit setup for linting and formatting + [\#6](https://github.com/joeloskarsson/neural-lam/pull/6), [\#8](https://github.com/joeloskarsson/neural-lam/pull/8) + @sadamov, @joeloskarsson + +- added github pull-request template to ease contribution and review process + [\#53](https://github.com/mllam/neural-lam/pull/53), @leifdenby + +- ci/cd setup for running both CPU and GPU-based testing both with pdm and pip based installs [\#37](https://github.com/mllam/neural-lam/pull/37), @khintz, @leifdenby + +### Changed + +- Clarify routine around requesting reviewer and assignee in PR template + [\#74](https://github.com/mllam/neural-lam/pull/74) + @joeloskarsson + +- Argument Parser updated to use action="store_true" instead of 0/1 for boolean arguments. + (https://github.com/mllam/neural-lam/pull/72) + @ErikLarssonDev + +- Optional multi-core/GPU support for statistics calculation in `create_parameter_weights.py` + [\#22](https://github.com/mllam/neural-lam/pull/22) + @sadamov + +- Robust restoration of optimizer and scheduler using `ckpt_path` + [\#17](https://github.com/mllam/neural-lam/pull/17) + @sadamov + +- Updated scripts and modules to use `data_config.yaml` instead of `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- Added new flags in `train_model.py` for configuration previously in `constants.py` + [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) + @sadamov + +- moved batch-static features ("water cover") into forcing component return by `WeatherDataset` + [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) + @joeloskarsson + +- change validation metric from `mae` to `rmse` + [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) + @joeloskarsson + +- change RMSE definition to compute sqrt after all averaging + [\#10](https://github.com/joeloskarsson/neural-lam/pull/10) + @joeloskarsson + +### Removed + +- `WeatherDataset(torch.Dataset)` no longer returns "batch-static" component of + training item (only `prev_state`, `target_state` and `forcing`), the batch static features are + instead included in forcing + [\#13](https://github.com/joeloskarsson/neural-lam/pull/13) + @joeloskarsson + +### Maintenance + +- simplify pre-commit setup by 1) reducing linting to only cover static + analysis excluding imports from external dependencies (this will be handled + in build/test cicd action introduced later), 2) pinning versions of linting + tools in pre-commit config (and remove from `requirements.txt`) and 3) using + github action to run pre-commit. + [\#29](https://github.com/mllam/neural-lam/pull/29) + @leifdenby + +- change copyright formulation in license to encompass all contributors + [\#47](https://github.com/mllam/neural-lam/pull/47) + @joeloskarsson + +- Fix incorrect ordering of x- and y-dimensions in comments describing tensor + shapes for MEPS data + [\#52](https://github.com/mllam/neural-lam/pull/52) + @joeloskarsson + +- Cap numpy version to < 2.0.0 (this cap was removed in #37, see below) + [\#68](https://github.com/mllam/neural-lam/pull/68) + @joeloskarsson + +- Remove numpy < 2.0.0 version cap + [\#37](https://github.com/mllam/neural-lam/pull/37) + @leifdenby + +- turn `neural-lam` into a python package by moving all `*.py`-files into the + `neural_lam/` source directory and updating imports accordingly. This means + all cli functions are now invoke through the package name, e.g. `python -m + neural_lam.train_model` instead of `python train_model.py` (and can be done + anywhere once the package has been installed). + [\#32](https://github.com/mllam/neural-lam/pull/32), @leifdenby + +- move from `requirements.txt` to `pyproject.toml` for defining package dependencies. + [\#37](https://github.com/mllam/neural-lam/pull/37), @leifdenby + +- Add slack and new publication info to readme + [\#78](https://github.com/mllam/neural-lam/pull/78) + @joeloskarsson + +## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0) + +First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication +() diff --git a/LICENSE.txt b/LICENSE.txt index 1bb69de2..ed176ba1 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Joel Oskarsson, Tomas Landelius +Copyright (c) 2023 Neural-LAM Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 67d9d9b1..20f09c86 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,14 @@ +[![slack](https://img.shields.io/badge/slack-join-brightgreen.svg?logo=slack)](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) +![Linting](https://github.com/mllam/neural-lam/actions/workflows/pre-commit.yml/badge.svg?branch=main) +[![test (pdm install, gpu)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-gpu.yml/badge.svg)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-gpu.yml) +[![test (pdm install, cpu)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-cpu.yml/badge.svg)](https://github.com/mllam/neural-lam/actions/workflows/ci-pdm-install-and-test-cpu.yml) +

Neural-LAM is a repository of graph-based neural weather prediction models for Limited Area Modeling (LAM). +Also global forecasting is possible, but currently on a [different branch](https://github.com/mllam/neural-lam/tree/prob_model_global) ([planned to be merged with main](https://github.com/mllam/neural-lam/issues/63)). The code uses [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/pytorch-lightning). Graph Neural Networks are implemented using [PyG](https://pyg.org/) and logging is set up through [Weights & Biases](https://wandb.ai/). @@ -12,8 +18,15 @@ The repository contains LAM versions of: * GraphCast, by [Lam et al. (2023)](https://arxiv.org/abs/2212.12794). * The hierarchical model from [Oskarsson et al. (2023)](https://arxiv.org/abs/2309.17370). -For more information see our paper: [*Graph-based Neural Weather Prediction for Limited Area Modeling*](https://arxiv.org/abs/2309.17370). -If you use Neural-LAM in your work, please cite: +# Publications +For a more in-depth scientific introduction to machine learning for LAM weather forecasting see the publications listed here. +As the code in the repository is continuously evolving, the latest version might feature some small differences to what was used for these publications. +We retain some paper-specific branches for reproducibility purposes. + + +*If you use Neural-LAM in your work, please cite the relevant paper(s)*. + +#### [Graph-based Neural Weather Prediction for Limited Area Modeling](https://arxiv.org/abs/2309.17370) ``` @inproceedings{oskarsson2023graphbased, title={Graph-based Neural Weather Prediction for Limited Area Modeling}, @@ -22,12 +35,20 @@ If you use Neural-LAM in your work, please cite: year={2023} } ``` -As the code in the repository is continuously evolving, the latest version might feature some small differences to what was used in the paper. -See the branch [`ccai_paper_2023`](https://github.com/joeloskarsson/neural-lam/tree/ccai_paper_2023) for a revision of the code that reproduces the workshop paper. +See the branch [`ccai_paper_2023`](https://github.com/joeloskarsson/neural-lam/tree/ccai_paper_2023) for a revision of the code that reproduces this workshop paper. -We plan to continue updating this repository as we improve existing models and develop new ones. -Collaborations around this implementation are very welcome. -If you are working with Neural-LAM feel free to get in touch and/or submit pull requests to the repository. +#### [Probabilistic Weather Forecasting with Hierarchical Graph Neural Networks](https://arxiv.org/abs/2406.04759) +``` +@inproceedings{oskarsson2024probabilistic, + title = {Probabilistic Weather Forecasting with Hierarchical Graph Neural Networks}, + author = {Oskarsson, Joel and Landelius, Tomas and Deisenroth, Marc Peter and Lindsten, Fredrik}, + booktitle = {Advances in Neural Information Processing Systems}, + volume = {37}, + year = {2024}, +} +``` +See the branches [`prob_model_lam`](https://github.com/mllam/neural-lam/tree/prob_model_lam) and [`prob_model_global`](https://github.com/mllam/neural-lam/tree/prob_model_global) for revisions of the code that reproduces this paper. +The global and probabilistic models from this paper are not yet fully merged with `main` (see issues [62](https://github.com/mllam/neural-lam/issues/62) and [63](https://github.com/mllam/neural-lam/issues/63)). # Modularity The Neural-LAM code is designed to modularize the different components involved in training and evaluating neural weather prediction models. @@ -42,75 +63,347 @@ Still, some restrictions are inevitable:

-## A note on the limited area setting -Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)). -There are still some parts of the code that is quite specific for the MEPS area use case. -This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants used (`neural_lam/constants.py`). -If there is interest to use Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic. -We would be happy to support such enhancements. -See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done. +# Installing Neural-LAM + +When installing `neural-lam` you have a choice of either installing with +directly `pip` or using the `pdm` package manager. +We recommend using `pdm` as it makes it easy to add/remove packages while +keeping versions consistent (it automatically updates the `pyproject.toml` +file), makes it easy to handle virtual environments and includes the +development toolchain packages installation too. + +**regarding `torch` installation**: because `torch` creates different package +variants for different CUDA versions and cpu-only support you will need to install +`torch` separately if you don't want the most recent GPU variant that also +expects the most recent version of CUDA on your system. + +We cover all the installation options in our [github actions ci/cd +setup](.github/workflows/) which you can use as a reference. + +## Using `pdm` + +1. Clone this repository and navigate to the root directory. +2. Install `pdm` if you don't have it installed on your system (either with `pip install pdm` or [following the install instructions](https://pdm-project.org/latest/#installation)). +> If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 5. +3. Create a virtual environment for pdm to use with `pdm venv create --with-pip`. +4. Install a specific version of `torch` with `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)). +5. Install the dependencies with `pdm install` (by default this in include the). If you will be developing `neural-lam` we recommend to install the development dependencies with `pdm install --group dev`. By default `pdm` installs the `neural-lam` package in editable mode, so you can make changes to the code and see the effects immediately. + +## Using `pip` + +1. Clone this repository and navigate to the root directory. +> If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 3. +2. Install a specific version of `torch` with `python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)). +3. Install the dependencies with `python -m pip install .`. If you will be developing `neural-lam` we recommend to install in editable mode and install the development dependencies with `python -m pip install -e ".[dev]"` so you can make changes to the code and see the effects immediately. + # Using Neural-LAM -Below follows instructions on how to use Neural-LAM to train and evaluate models. -## Installation -Follow the steps below to create the necessary python environment. +Once `neural-lam` is installed you will be able to train/evaluate models. For this you will in general need two things: + +1. **Data to train/evaluate the model**. To represent this data we use a concept of + *datastores* in Neural-LAM (see the [Data](#data-the-datastore-and-weatherdataset-classes) section for more details). + In brief, a datastore implements the process of loading data from disk in a + specific format (for example zarr or numpy files) by implementing an + interface that provides the data in a data-structure that can be used within + neural-lam. A datastore is used to create a `pytorch.Dataset`-derived + class that samples the data in time to create individual samples for + training, validation and testing. A secondary datastore can be provided + for the boundary data. Currently, boundary datastore must be of type `mdp` + and only contain forcing features. This can easily be expanded in the future. + +2. **The graph structure** is used to define message-passing GNN layers, + that are trained to emulate fluid flow in the atmosphere over time. The + graph structure is created for a specific datastore. + +Any command you run in neural-lam will include the path to a configuration file +to be used (usually called `config.yaml`). This configuration file defines the +path to the datastore configuration you wish to use and allows you to configure +different aspects about the training and evaluation of the model. + +The path you provide to the neural-lam config (`config.yaml`) also sets the +root directory relative to which all other paths are resolved, as in the parent +directory of the config becomes the root directory. Both the datastores and +graphs you generate are then stored in subdirectories of this root directory. +Exactly how and where a specific datastore expects its source data to be stored +and where it stores its derived data is up to the implementation of the +datastore. + +In general the folder structure assumed in Neural-LAM is as follows (we will +assume you placed `config.yaml` in a folder called `data`): + +``` +data/ +├── config.yaml - Configuration file for neural-lam +├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml +├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml +└── graphs/ - Directory containing graphs for training +``` + +And the content of `config.yaml` could in this case look like: +```yaml +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 + t2m: 1.0 + r2m: 1.0 + output_clamping: + lower: + t2m: 0.0 + r2m: 0 + upper: + r2m: 1.0 +``` -1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement. -2. Use python 3.9. -3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system. -4. Install required packages specified in `requirements.txt`. -5. Install PyTorch Geometric version 2.2.0. This can be done by running +For now the neural-lam config only defines few things: + +1. The kind of datastore and the path to its config +2. The weighting of different features in +the loss function. If you don't define the state feature weighting it will default to +weighting all features equally. +3. Valid numerical range for output of each feature.The numerical range of all features default to $]-\infty, \infty[$. + +(This example is taken from the `tests/datastore_examples/mdp` directory.) + + +Below follows instructions on how to use Neural-LAM to train and evaluate +models, with details first given for each kind of datastore implemented +and later the graph generation. Once `neural-lam` has been installed the +general process is: + +1. Run any pre-processing scripts to generate the necessary derived data that your chosen datastore requires +2. Run graph-creation step +3. Train the model + +## Data (the `DataStore` and `WeatherDataset` classes) + +To enable flexibility in what input-data sources can be used with neural-lam, +the input-data representation is split into two parts: + +1. A "datastore" (represented by instances of + [neural_lam.datastore.BaseDataStore](neural_lam/datastore/base.py)) which + takes care of loading a given category (state, forcing or static) and split + (train/val/test) of data from disk and returning it as a `xarray.DataArray`. + The returned data-array is expected to have the spatial coordinates + flattened into a single `grid_index` dimension and all variables and vertical + levels stacked into a feature dimension (named as `{category}_feature`). The + datastore also provides information about the number, names and units of + variables in the data, the boundary mask, normalisation values and grid + information. + +2. A `pytorch.Dataset`-derived class (called + `neural_lam.weather_dataset.WeatherDataset`) which takes care of sampling in + time to create individual samples for training, validation and testing. The + `WeatherDataset` class is also responsible for normalising the values and + returning `torch.Tensor`-objects. + +There are currently two different datastores implemented in the codebase: + +1. `neural_lam.datastore.MDPDatastore` which represents loading of + *training-ready* datasets in zarr format created with the + [mllam-data-prep](https://github.com/mllam/mllam-data-prep) package. + Training-ready refers to the fact that this data has been transformed + (variables have been stacked, spatial coordinates have been flattened, + statistics for normalisation have been calculated, etc) to be ready for + training. `mllam-data-prep` can combine any number of datasets that can be + read with [xarray](https://github.com/pydata/xarray) and the processing can + either be done at run-time or as a pre-processing step before calling + neural-lam. + +2. `neural_lam.datastore.NpyFilesDatastoreMEPS` which reads MEPS data from + `.npy`-files in the format introduced in neural-lam `v0.1.0`. Note that this + datastore is specific to the format of the MEPS dataset, but can act as an + example for how to create similar numpy-based datastores. + +If neither of these options fit your need you can create your own datastore by +subclassing the `neural_lam.datastore.BaseDataStore` class or +`neural_lam.datastore.BaseRegularGridDatastore` class (if your data is stored on +a regular grid) and implementing the abstract methods. + + +### MDP (mllam-data-prep) Datastore - `MDPDatastore` + +With `MDPDatastore` (the mllam-data-prep datastore) all the selection, +transformation and pre-calculation steps that are needed to go from +for example gridded weather data to a format that is optimised for training +in neural-lam, are done in a separate package called +[mllam-data-prep](https://github.com/mllam/mllam-data-prep) rather than in +neural-lam itself. +Specifically, the `mllam-data-prep` datastore configuration (for example +[danra.datastore.yaml](tests/datastore_examples/mdp/danra.datastore.yaml)) +specifies a) what source datasets to read from, b) what variables to select, c) +what transformations of dimensions and variables to make, d) what statistics to +calculate (for normalisation) and e) how to split the data into training, +validation and test sets (see full details about the configuration specification +in the [mllam-data-prep README](https://github.com/mllam/mllam-data-prep)). + +From a datastore configuration `mllam-data-prep` returns the transformed +dataset as an `xr.Dataset` which is then written in zarr-format to disk by +`neural-lam` when the datastore is first initiated (the path of the dataset is +derived from the datastore config, so that from a config named `danra.datastore.yaml` the resulting dataset is stored in `danra.datastore.zarr`). +You can also run `mllam-data-prep` directly to create the processed dataset by providing the path to the datastore configuration file: + +```bash +python -m mllam_data_prep --config data/danra.datastore.yaml ``` -TORCH="2.0.1" -CUDA="cu117" -pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 torch-cluster==1.6.1\ - torch-geometric==2.3.1 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html +If you will be working on a large dataset (on the order of 10GB or more) it +could be beneficial to produce the processed `.zarr` dataset before using it +in neural-lam so that you can do the processing across multiple CPU cores in parallel. This is done by including the `--dask-distributed-local-core-fraction` argument when calling mllam-data-prep to set the fraction of your system's CPU cores that should be used for processing (see the +[mllam-data-prep +README for details](https://github.com/mllam/mllam-data-prep?tab=readme-ov-file#creating-large-datasets-with-daskdistributed)). + +For example: + +```bash +python -m mllam_data_prep --config data/danra.datastore.yaml --dask-distributed-local-core-fraction 0.5 ``` -You will have to adjust the `CUDA` variable to match the CUDA version on your system or to run on CPU. See the [installation webpage](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) for more information. -## Data -Datasets should be stored in a directory called `data`. -See the [repository format section](#format-of-data-directory) for details on the directory structure. +### NpyFiles MEPS Datastore - `NpyFilesDatastoreMEPS` + +Version `v0.1.0` of Neural-LAM was built to train from numpy-files from the +MEPS weather forecasts dataset. +To enable this functionality to live on in later versions of neural-lam we have +built a datastore called `NpyFilesDatastoreMEPS` which implements functionality +to read from these exact same numpy-files. At this stage this datastore class +is very much tied to the MEPS dataset, but the code is written in a way where +it quite easily could be adapted to work with numpy-based weather +forecast/analysis files in future. The full MEPS dataset can be shared with other researchers on request, contact us for this. -A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX). +A tiny subset of the data (named `meps_example`) is available in +`example_data.zip`, which can be downloaded from +[here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX). + Download the file and unzip in the neural-lam directory. -All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `create_mesh.py`). -Note that this is far too little data to train any useful models, but all scripts can be ran with it. +Graphs used in the initial paper are also available for download at the same link (but can as easily be re-generated using `python -m neural_lam.create_graph`). +Note that this is far too little data to train any useful models, but all pre-processing and training steps can be run with it. It should thus be useful to make sure that your python environment is set up correctly and that all the code can be ran without any issues. -## Pre-processing -An overview of how the different scripts and files depend on each other is given in this figure: -

- -

-In order to start training models at least three pre-processing scripts have to be ran: +The following datastore configuration works with the MEPS dataset: + +```yaml +# meps.datastore.yaml +dataset: + name: meps_example + num_forcing_features: 16 + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + - t_isobaricInhPa_500_instant + - t_isobaricInhPa_850_instant + - u_hybrid_65_instant + - u_isobaricInhPa_850_instant + - v_hybrid_65_instant + - v_isobaricInhPa_850_instant + - wvint_entireAtmosphere_0_instant + - z_isobaricInhPa_1000_instant + - z_isobaricInhPa_500_instant + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + - t_500 + - t_850 + - u_65 + - u_850 + - v_65 + - v_850 + - wvint_0 + - z_1000 + - z_500 + var_units: + - Pa + - Pa + - W/m\textsuperscript{2} + - W/m\textsuperscript{2} + - "-" + - "-" + - K + - K + - K + - K + - m/s + - m/s + - m/s + - m/s + - kg/m\textsuperscript{2} + - m\textsuperscript{2}/s\textsuperscript{2} + - m\textsuperscript{2}/s\textsuperscript{2} + num_timesteps: 65 + num_ensemble_members: 2 + step_length: 3 + remove_state_features_with_index: [15] +grid_shape_state: +- 268 +- 238 +projection: + class_name: LambertConformal + kwargs: + central_latitude: 63.3 + central_longitude: 15.0 + standard_parallels: + - 63.3 + - 63.3 +``` -* `create_mesh.py` -* `create_grid_features.py` -* `create_parameter_weights.py` +Which you can then use in a neural-lam configuration file like this: + +```yaml +# config.yaml +datastore: + kind: npyfilesmeps + config_path: meps.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + values: + u100m: 1.0 + v100m: 1.0 +``` -### Create graph -Run `create_mesh.py` with suitable options to generate the graph you want to use (see `python create_mesh.py --help` for a list of options). -The graphs used for the different models in the [paper](https://arxiv.org/abs/2309.17370) can be created as: +For npy-file based datastores you must separately run the command that creates the variables used for standardization: -* **GC-LAM**: `python create_mesh.py --graph multiscale` -* **Hi-LAM**: `python create_mesh.py --graph hierarchical --hierarchical 1` (also works for Hi-LAM-Parallel) -* **L1-LAM**: `python create_mesh.py --graph 1level --levels 1` +```bash +python -m neural_lam.datastore.npyfilesmeps.compute_standardization_stats +``` -The graph-related files are stored in a directory called `graphs`. +### Graph creation + +Run `python -m neural_lam.create_mesh` with suitable options to generate the graph you want to use (see `python neural_lam.create_mesh --help` for a list of options). +The graphs used for the different models in the [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling) can be created as: -### Create remaining static features -To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`. -The main option to set for these is just which dataset to use. +* **GC-LAM**: `python -m neural_lam.create_graph --config_path --name multiscale` +* **Hi-LAM**: `python -m neural_lam.create_graph --config_path --name hierarchical --hierarchical` (also works for Hi-LAM-Parallel) +* **L1-LAM**: `python -m neural_lam.create_graph --config_path --name 1level --levels 1` + +The graph-related files are stored in a directory called `graphs`. ## Weights & Biases Integration The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it. When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface. If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`. -The W&B project name is set to `neural-lam`, but this can be changed in `neural_lam/constants.py`. +The W&B project name is set to `neural-lam`, but this can be changed in the flags of `python -m neural_lam.train_model` (using argsparse). See the [W&B documentation](https://docs.wandb.ai/) for details. If you would like to login and use W&B, run: @@ -123,15 +416,17 @@ wandb off ``` ## Train Models -Models can be trained using `train_model.py`. -Run `python train_model.py --help` for a full list of training options. +Models can be trained using `python -m neural_lam.train_model --config_path `. +Run `python neural_lam.train_model --help` for a full list of training options. A few of the key ones are outlined below: -* `--dataset`: Which data to train on +* `--config_path`: Path to the configuration for neural-lam (for example in `data/myexperiment/config.yaml`). * `--model`: Which model to train * `--graph`: Which graph to use with the model +* `--epochs`: Number of epochs to train for * `--processor_layers`: Number of GNN layers to use in the processing part of the model -* `--ar_steps`: Number of time steps to unroll for when making predictions and computing the loss +* `--ar_steps_train`: Number of time steps to unroll for when making predictions and computing the loss +* `--ar_steps_eval`: Number of time steps to unroll for during validation steps Checkpoints of trained models are stored in the `saved_models` directory. The implemented models are: @@ -139,16 +434,16 @@ The implemented models are: ### Graph-LAM This is the basic graph-based LAM model. The encode-process-decode framework is used with a mesh graph in order to make one-step pedictions. -This model class is used both for the L1-LAM and GC-LAM models from the [paper](https://arxiv.org/abs/2309.17370), only with different graphs. +This model class is used both for the L1-LAM and GC-LAM models from the [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling), only with different graphs. To train 1L-LAM use ``` -python train_model.py --model graph_lam --graph 1level ... +python -m neural_lam.train_model --model graph_lam --graph 1level ... ``` To train GC-LAM use ``` -python train_model.py --model graph_lam --graph multiscale ... +python -m neural_lam.train_model --model graph_lam --graph multiscale ... ``` ### Hi-LAM @@ -156,7 +451,7 @@ A version of Graph-LAM that uses a hierarchical mesh graph and performs sequenti To train Hi-LAM use ``` -python train_model.py --model hi_lam --graph hierarchical ... +python -m neural_lam.train_model --model hi_lam --graph hierarchical ... ``` ### Hi-LAM-Parallel @@ -165,66 +460,29 @@ Not included in the paper as initial experiments showed worse results than Hi-LA To train Hi-LAM-Parallel use ``` -python train_model.py --model hi_lam_parallel --graph hierarchical ... +python -m neural_lam.train_model --model hi_lam_parallel --graph hierarchical ... ``` Checkpoint files for our models trained on the MEPS data are available upon request. ## Evaluate Models -Evaluation is also done using `train_model.py`, but using the `--eval` option. +Evaluation is also done using `python -m neural_lam.train_model --config_path `, but using the `--eval` option. Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data. -Most of the training options are also relevant for evaluation (not `ar_steps`, evaluation always unrolls full forecasts). +Most of the training options are also relevant for evaluation. Some options specifically important for evaluation are: * `--load`: Path to model checkpoint file (`.ckpt`) to load parameters from * `--n_example_pred`: Number of example predictions to plot during evaluation. +* `--ar_steps_eval`: Number of time steps to unroll for during evaluation -**Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example [this draft PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion and ongoing work to remedy this. +**Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. +A possible workaround is to just use batch size 1 during evaluation. +This issue stems from PyTorch Lightning. See for example [this PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion. # Repository Structure Except for training and pre-processing scripts all the source code can be found in the `neural_lam` directory. Model classes, including abstract base classes, are located in `neural_lam/models`. - -## Format of data directory -It is possible to store multiple datasets in the `data` directory. -Each dataset contains a set of files with static features and a set of samples. -The samples are split into different sub-directories for training, validation and testing. -The directory structure is shown with examples below. -Script names within parenthesis denote the script used to generate the file. -``` -data -├── dataset1 -│ ├── samples - Directory with data samples -│ │ ├── train - Training data -│ │ │ ├── nwp_2022040100_mbr000.npy - A time series sample -│ │ │ ├── nwp_2022040100_mbr001.npy -│ │ │ ├── ... -│ │ │ ├── nwp_2022043012_mbr001.npy -│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy - Solar flux forcing -│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy -│ │ │ ├── ... -│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022043012.npy -│ │ │ ├── wtr_2022040100.npy - Open water features for one sample -│ │ │ ├── wtr_2022040112.npy -│ │ │ ├── ... -│ │ │ └── wtr_202204012.npy -│ │ ├── val - Validation data -│ │ └── test - Test data -│ └── static - Directory with graph information and static features -│ ├── nwp_xy.npy - Coordinates of grid nodes (part of dataset) -│ ├── surface_geopotential.npy - Geopotential at surface of grid nodes (part of dataset) -│ ├── border_mask.npy - Mask with True for grid nodes that are part of border (part of dataset) -│ ├── grid_features.pt - Static features of grid nodes (create_grid_features.py) -│ ├── parameter_mean.pt - Means of state parameters (create_parameter_weights.py) -│ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py) -│ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py) -│ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py) -│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py) -│ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py) -├── dataset2 -├── ... -└── datasetN -``` +Notebooks for visualization and analysis are located in `docs`. ## Format of graph directory The `graphs` directory contains generated graph structures that can be used by different graph-based models. @@ -232,13 +490,13 @@ The structure is shown with examples below: ``` graphs ├── graph1 - Directory with a graph definition -│ ├── m2m_edge_index.pt - Edges in mesh graph (create_mesh.py) -│ ├── g2m_edge_index.pt - Edges from grid to mesh (create_mesh.py) -│ ├── m2g_edge_index.pt - Edges from mesh to grid (create_mesh.py) -│ ├── m2m_features.pt - Static features of mesh edges (create_mesh.py) -│ ├── g2m_features.pt - Static features of grid to mesh edges (create_mesh.py) -│ ├── m2g_features.pt - Static features of mesh to grid edges (create_mesh.py) -│ └── mesh_features.pt - Static features of mesh nodes (create_mesh.py) +│ ├── m2m_edge_index.pt - Edges in mesh graph (neural_lam.create_mesh) +│ ├── g2m_edge_index.pt - Edges from grid to mesh (neural_lam.create_mesh) +│ ├── m2g_edge_index.pt - Edges from mesh to grid (neural_lam.create_mesh) +│ ├── m2m_features.pt - Static features of mesh edges (neural_lam.create_mesh) +│ ├── g2m_features.pt - Static features of grid to mesh edges (neural_lam.create_mesh) +│ ├── m2g_features.pt - Static features of mesh to grid edges (neural_lam.create_mesh) +│ └── mesh_features.pt - Static features of mesh nodes (neural_lam.create_mesh) ├── graph2 ├── ... └── graphN @@ -248,9 +506,9 @@ graphs To keep track of levels in the mesh graph, a list format is used for the files with mesh graph information. In particular, the files ``` -│ ├── m2m_edge_index.pt - Edges in mesh graph (create_mesh.py) -│ ├── m2m_features.pt - Static features of mesh edges (create_mesh.py) -│ ├── mesh_features.pt - Static features of mesh nodes (create_mesh.py) +│ ├── m2m_edge_index.pt - Edges in mesh graph (neural_lam.create_mesh) +│ ├── m2m_features.pt - Static features of mesh edges (neural_lam.create_mesh) +│ ├── mesh_features.pt - Static features of mesh nodes (neural_lam.create_mesh) ``` all contain lists of length `L`, for a hierarchical mesh graph with `L` layers. For non-hierarchical graphs `L == 1` and these are all just singly-entry lists. @@ -261,10 +519,10 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w ``` ├── graph1 │ ├── ... -│ ├── mesh_down_edge_index.pt - Downward edges in mesh graph (create_mesh.py) -│ ├── mesh_up_edge_index.pt - Upward edges in mesh graph (create_mesh.py) -│ ├── mesh_down_features.pt - Static features of downward mesh edges (create_mesh.py) -│ ├── mesh_up_features.pt - Static features of upward mesh edges (create_mesh.py) +│ ├── mesh_down_edge_index.pt - Downward edges in mesh graph (neural_lam.create_mesh) +│ ├── mesh_up_edge_index.pt - Upward edges in mesh graph (neural_lam.create_mesh) +│ ├── mesh_down_features.pt - Static features of downward mesh edges (neural_lam.create_mesh) +│ ├── mesh_up_features.pt - Static features of upward mesh edges (neural_lam.create_mesh) │ ├── ... ``` These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels). @@ -280,6 +538,8 @@ pre-commit run --all-files ``` from the root directory of the repository. +Furthermore, all tests in the ```tests``` directory will be run upon pushing changes by a github action. Failure in any of the tests will also reject the push/PR. + # Contact -If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. -You can open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). +If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch. +There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/create_mesh.py b/create_mesh.py deleted file mode 100644 index 2b6af9fd..00000000 --- a/create_mesh.py +++ /dev/null @@ -1,501 +0,0 @@ -# Standard library -import os -from argparse import ArgumentParser - -# Third-party -import matplotlib -import matplotlib.pyplot as plt -import networkx -import numpy as np -import scipy.spatial -import torch -import torch_geometric as pyg -from torch_geometric.utils.convert import from_networkx - -# First-party -from neural_lam import utils - -# matplotlib.use('TkAgg') - - -def plot_graph(graph, title=None): - fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H - edge_index = graph.edge_index - pos = graph.pos - - # Fix for re-indexed edge indices only containing mesh nodes at - # higher levels in hierarchy - edge_index = edge_index - edge_index.min() - - if pyg.utils.is_undirected(edge_index): - # Keep only 1 direction of edge_index - edge_index = edge_index[:, edge_index[0] < edge_index[1]] # (2, M/2) - # TODO: indicate direction of directed edges - - # Move all to cpu and numpy, compute (in)-degrees - degrees = ( - pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() - ) - edge_index = edge_index.cpu().numpy() - pos = pos.cpu().numpy() - - # Plot edges - from_pos = pos[edge_index[0]] # (M/2, 2) - to_pos = pos[edge_index[1]] # (M/2, 2) - edge_lines = np.stack((from_pos, to_pos), axis=1) - axis.add_collection( - matplotlib.collections.LineCollection( - edge_lines, lw=0.4, colors="black", zorder=1 - ) - ) - - # Plot nodes - node_scatter = axis.scatter( - pos[:, 0], - pos[:, 1], - c=degrees, - s=3, - marker="o", - zorder=2, - cmap="viridis", - clim=None, - ) - - plt.colorbar(node_scatter, aspect=50) - - if title is not None: - axis.set_title(title) - - return fig, axis - - -def sort_nodes_internally(nx_graph): - # For some reason the networkx .nodes() return list can not be sorted, - # but this is the ordering used by pyg when converting. - # This function fixes this. - H = networkx.DiGraph() - H.add_nodes_from(sorted(nx_graph.nodes(data=True))) - H.add_edges_from(nx_graph.edges(data=True)) - return H - - -def save_edges(graph, name, base_path): - torch.save( - graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt") - ) - edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32 - ) # Save as float32 - torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) - - -def save_edges_list(graphs, name, base_path): - torch.save( - [graph.edge_index for graph in graphs], - os.path.join(base_path, f"{name}_edge_index.pt"), - ) - edge_features = [ - torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32 - ) - for graph in graphs - ] # Save as float32 - torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) - - -def from_networkx_with_start_index(nx_graph, start_index): - pyg_graph = from_networkx(nx_graph) - pyg_graph.edge_index += start_index - return pyg_graph - - -def mk_2d_graph(xy, nx, ny): - xm, xM = np.amin(xy[0][0, :]), np.amax(xy[0][0, :]) - ym, yM = np.amin(xy[1][:, 0]), np.amax(xy[1][:, 0]) - - # avoid nodes on border - dx = (xM - xm) / nx - dy = (yM - ym) / ny - lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) - ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) - - mg = np.meshgrid(lx, ly) - g = networkx.grid_2d_graph(len(ly), len(lx)) - - for node in g.nodes: - g.nodes[node]["pos"] = np.array([mg[0][node], mg[1][node]]) - - # add diagonal edges - g.add_edges_from( - [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] - + [ - ((x + 1, y), (x, y + 1)) - for x in range(nx - 1) - for y in range(ny - 1) - ] - ) - - # turn into directed graph - dg = networkx.DiGraph(g) - for u, v in g.edges(): - d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2)) - dg.edges[u, v]["len"] = d - dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"] - dg.add_edge(v, u) - dg.edges[v, u]["len"] = d - dg.edges[v, u]["vdiff"] = g.nodes[v]["pos"] - g.nodes[u]["pos"] - - return dg - - -def prepend_node_index(graph, new_index): - # Relabel node indices in graph, insert (graph_level, i, j) - ijk = [tuple((new_index,) + x) for x in graph.nodes] - to_mapping = dict(zip(graph.nodes, ijk)) - return networkx.relabel_nodes(graph, to_mapping, copy=True) - - -def main(): - parser = ArgumentParser(description="Graph generation arguments") - parser.add_argument( - "--graph", - type=str, - default="multiscale", - help="Name to save graph as (default: multiscale)", - ) - parser.add_argument( - "--plot", - type=int, - default=0, - help="If graphs should be plotted during generation " - "(default: 0 (false))", - ) - parser.add_argument( - "--levels", - type=int, - help="Limit multi-scale mesh to given number of levels, " - "from bottom up (default: None (no limit))", - ) - parser.add_argument( - "--hierarchical", - type=int, - default=0, - help="Generate hierarchical mesh graph (default: 0, no)", - ) - parser.add_argument( - "--data_config", - type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", - ) - - args = parser.parse_args() - - # Load grid positions - graph_dir_path = os.path.join("graphs", args.graph) - os.makedirs(graph_dir_path, exist_ok=True) - - config_loader = utils.ConfigLoader(args.data_config) - xy = config_loader.get_nwp_xy() - grid_xy = torch.tensor(xy) - pos_max = torch.max(torch.abs(grid_xy)) - - # - # Mesh - # - - # graph geometry - nx = 3 # number of children = nx**2 - nlev = int(np.log(max(xy.shape)) / np.log(nx)) - nleaf = nx**nlev # leaves at the bottom = nleaf**2 - - mesh_levels = nlev - 1 - if args.levels: - # Limit the levels in mesh graph - mesh_levels = min(mesh_levels, args.levels) - - print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}") - - # multi resolution tree levels - G = [] - for lev in range(1, mesh_levels + 1): - n = int(nleaf / (nx**lev)) - g = mk_2d_graph(xy, n, n) - if args.plot: - plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}") - plt.show() - - G.append(g) - - if args.hierarchical: - # Relabel nodes of each level with level index first - G = [ - prepend_node_index(graph, level_i) - for level_i, graph in enumerate(G) - ] - - num_nodes_level = np.array([len(g_level.nodes) for g_level in G]) - # First node index in each level in the hierarchical graph - first_index_level = np.concatenate( - (np.zeros(1, dtype=int), np.cumsum(num_nodes_level[:-1])) - ) - - # Create inter-level mesh edges - up_graphs = [] - down_graphs = [] - for from_level, to_level, G_from, G_to, start_index in zip( - range(1, mesh_levels), - range(0, mesh_levels - 1), - G[1:], - G[:-1], - first_index_level[: mesh_levels - 1], - ): - # start out from graph at from level - G_down = G_from.copy() - G_down.clear_edges() - G_down = networkx.DiGraph(G_down) - - # Add nodes of to level - G_down.add_nodes_from(G_to.nodes(data=True)) - - # build kd tree for mesh point pos - # order in vm should be same as in vm_xy - v_to_list = list(G_to.nodes) - v_from_list = list(G_from.nodes) - v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) - kdt_m = scipy.spatial.KDTree(v_from_xy) - - # add edges from mesh to grid - for v in v_to_list: - # find 1(?) nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] - u = v_from_list[neigh_idx] - - # add edge from mesh to grid - G_down.add_edge(u, v) - d = np.sqrt( - np.sum( - (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2 - ) - ) - G_down.edges[u, v]["len"] = d - G_down.edges[u, v]["vdiff"] = ( - G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] - ) - - # relabel nodes to integers (sorted) - G_down_int = networkx.convert_node_labels_to_integers( - G_down, first_label=start_index, ordering="sorted" - ) # Issue with sorting here - G_down_int = sort_nodes_internally(G_down_int) - pyg_down = from_networkx_with_start_index(G_down_int, start_index) - - # Create up graph, invert downwards edges - up_edges = torch.stack( - (pyg_down.edge_index[1], pyg_down.edge_index[0]), dim=0 - ) - pyg_up = pyg_down.clone() - pyg_up.edge_index = up_edges - - up_graphs.append(pyg_up) - down_graphs.append(pyg_down) - - if args.plot: - plot_graph( - pyg_down, title=f"Down graph, {from_level} -> {to_level}" - ) - plt.show() - - plot_graph( - pyg_down, title=f"Up graph, {to_level} -> {from_level}" - ) - plt.show() - - # Save up and down edges - save_edges_list(up_graphs, "mesh_up", graph_dir_path) - save_edges_list(down_graphs, "mesh_down", graph_dir_path) - - # Extract intra-level edges for m2m - m2m_graphs = [ - from_networkx_with_start_index( - networkx.convert_node_labels_to_integers( - level_graph, first_label=start_index, ordering="sorted" - ), - start_index, - ) - for level_graph, start_index in zip(G, first_index_level) - ] - - mesh_pos = [graph.pos.to(torch.float32) for graph in m2m_graphs] - - # For use in g2m and m2g - G_bottom_mesh = G[0] - - joint_mesh_graph = networkx.union_all([graph for graph in G]) - all_mesh_nodes = joint_mesh_graph.nodes(data=True) - - else: - # combine all levels to one graph - G_tot = G[0] - for lev in range(1, len(G)): - nodes = list(G[lev - 1].nodes) - n = int(np.sqrt(len(nodes))) - ij = ( - np.array(nodes) - .reshape((n, n, 2))[1::nx, 1::nx, :] - .reshape(int(n / nx) ** 2, 2) - ) - ij = [tuple(x) for x in ij] - G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) - G_tot = networkx.compose(G_tot, G[lev]) - - # Relabel mesh nodes to start with 0 - G_tot = prepend_node_index(G_tot, 0) - - # relabel nodes to integers (sorted) - G_int = networkx.convert_node_labels_to_integers( - G_tot, first_label=0, ordering="sorted" - ) - - # Graph to use in g2m and m2g - G_bottom_mesh = G_tot - all_mesh_nodes = G_tot.nodes(data=True) - - # export the nx graph to PyTorch geometric - pyg_m2m = from_networkx(G_int) - m2m_graphs = [pyg_m2m] - mesh_pos = [pyg_m2m.pos.to(torch.float32)] - - if args.plot: - plot_graph(pyg_m2m, title="Mesh-to-mesh") - plt.show() - - # Save m2m edges - save_edges_list(m2m_graphs, "m2m", graph_dir_path) - - # Divide mesh node pos by max coordinate of grid cell - mesh_pos = [pos / pos_max for pos in mesh_pos] - - # Save mesh positions - torch.save( - mesh_pos, os.path.join(graph_dir_path, "mesh_features.pt") - ) # mesh pos, in float32 - - # - # Grid2Mesh - # - - # radius within which grid nodes are associated with a mesh node - # (in terms of mesh distance) - DM_SCALE = 0.67 - - # mesh nodes on lowest level - vm = G_bottom_mesh.nodes - vm_xy = np.array([xy for _, xy in vm.data("pos")]) - # distance between mesh nodes - dm = np.sqrt( - np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) - ) - - # grid nodes - Ny, Nx = xy.shape[1:] - - G_grid = networkx.grid_2d_graph(Ny, Nx) - G_grid.clear_edges() - - # vg features (only pos introduced here) - for node in G_grid.nodes: - # pos is in feature but here explicit for convenience - G_grid.nodes[node]["pos"] = np.array([xy[0][node], xy[1][node]]) - - # add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes - # (i,j) and impose sorting order such that vm are the first nodes - G_grid = prepend_node_index(G_grid, 1000) - - # build kd tree for grid point pos - # order in vg_list should be same as in vg_xy - vg_list = list(G_grid.nodes) - vg_xy = np.array([[xy[0][node[1:]], xy[1][node[1:]]] for node in vg_list]) - kdt_g = scipy.spatial.KDTree(vg_xy) - - # now add (all) mesh nodes, include features (pos) - G_grid.add_nodes_from(all_mesh_nodes) - - # Re-create graph with sorted node indices - # Need to do sorting of nodes this way for indices to map correctly to pyg - G_g2m = networkx.Graph() - G_g2m.add_nodes_from(sorted(G_grid.nodes(data=True))) - - # turn into directed graph - G_g2m = networkx.DiGraph(G_g2m) - - # add edges - for v in vm: - # find neighbours (index to vg_xy) - neigh_idxs = kdt_g.query_ball_point(vm[v]["pos"], dm * DM_SCALE) - for i in neigh_idxs: - u = vg_list[i] - # add edge from grid to mesh - G_g2m.add_edge(u, v) - d = np.sqrt( - np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2) - ) - G_g2m.edges[u, v]["len"] = d - G_g2m.edges[u, v]["vdiff"] = ( - G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"] - ) - - pyg_g2m = from_networkx(G_g2m) - - if args.plot: - plot_graph(pyg_g2m, title="Grid-to-mesh") - plt.show() - - # - # Mesh2Grid - # - - # start out from Grid2Mesh and then replace edges - G_m2g = G_g2m.copy() - G_m2g.clear_edges() - - # build kd tree for mesh point pos - # order in vm should be same as in vm_xy - vm_list = list(vm) - kdt_m = scipy.spatial.KDTree(vm_xy) - - # add edges from mesh to grid - for v in vg_list: - # find 4 nearest neighbours (index to vm_xy) - neigh_idxs = kdt_m.query(G_m2g.nodes[v]["pos"], 4)[1] - for i in neigh_idxs: - u = vm_list[i] - # add edge from mesh to grid - G_m2g.add_edge(u, v) - d = np.sqrt( - np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2) - ) - G_m2g.edges[u, v]["len"] = d - G_m2g.edges[u, v]["vdiff"] = ( - G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"] - ) - - # relabel nodes to integers (sorted) - G_m2g_int = networkx.convert_node_labels_to_integers( - G_m2g, first_label=0, ordering="sorted" - ) - pyg_m2g = from_networkx(G_m2g_int) - - if args.plot: - plot_graph(pyg_m2g, title="Mesh-to-grid") - plt.show() - - # Save g2m and m2g everything - # g2m - save_edges(pyg_g2m, "g2m", graph_dir_path) - # m2g - save_edges(pyg_m2g, "m2g", graph_dir_path) - - -if __name__ == "__main__": - main() diff --git a/create_parameter_weights.py b/create_parameter_weights.py deleted file mode 100644 index 1eda7a24..00000000 --- a/create_parameter_weights.py +++ /dev/null @@ -1,125 +0,0 @@ -# Standard library -from argparse import ArgumentParser - -# Third-party -import torch -import xarray as xr -from tqdm import tqdm - -# First-party -from neural_lam.weather_dataset import WeatherDataModule - - -def main(): - """ - Pre-compute parameter weights to be used in loss function - """ - parser = ArgumentParser(description="Training arguments") - parser.add_argument( - "--batch_size", - type=int, - default=32, - help="Batch size when iterating over the dataset", - ) - parser.add_argument( - "--num_workers", - type=int, - default=4, - help="Number of workers in data loader (default: 4)", - ) - parser.add_argument( - "--zarr_path", - type=str, - default="normalization.zarr", - help="Directory where data is stored", - ) - - args = parser.parse_args() - - data_module = WeatherDataModule( - batch_size=args.batch_size, num_workers=args.num_workers - ) - data_module.setup() - loader = data_module.train_dataloader() - - # Load dataset without any subsampling - # Compute mean and std.-dev. of each parameter (+ forcing forcing) - # across full dataset - print("Computing mean and std.-dev. for parameters...") - means = [] - squares = [] - fb_means = {"forcing": [], "boundary": []} - fb_squares = {"forcing": [], "boundary": []} - - for init_batch, target_batch, forcing_batch, boundary_batch, _ in tqdm( - loader - ): - batch = torch.cat( - (init_batch, target_batch), dim=1 - ) # (N_batch, N_t, N_grid, d_features) - means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) - squares.append(torch.mean(batch**2, dim=(1, 2))) - - for fb_type, fb_batch in zip( - ["forcing", "boundary"], [forcing_batch, boundary_batch] - ): - fb_batch = fb_batch[:, :, :, 1] - fb_means[fb_type].append(torch.mean(fb_batch)) # (,) - fb_squares[fb_type].append(torch.mean(fb_batch**2)) # (,) - - mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) - second_moment = torch.mean(torch.cat(squares, dim=0), dim=0) - std = torch.sqrt(second_moment - mean**2) # (d_features) - - fb_stats = {} - for fb_type in ["forcing", "boundary"]: - fb_stats[f"{fb_type}_mean"] = torch.mean( - torch.stack(fb_means[fb_type]) - ) # (,) - fb_second_moment = torch.mean(torch.stack(fb_squares[fb_type])) # (,) - fb_stats[f"{fb_type}_std"] = torch.sqrt( - fb_second_moment - fb_stats[f"{fb_type}_mean"] ** 2 - ) # (,) - - # Compute mean and std.-dev. of one-step differences across the dataset - print("Computing mean and std.-dev. for one-step differences...") - diff_means = [] - diff_squares = [] - for init_batch, target_batch, _, _, _ in tqdm(loader): - # normalize the batch - init_batch = (init_batch - mean) / std - target_batch = (target_batch - mean) / std - - batch = torch.cat((init_batch, target_batch), dim=1) - batch_diffs = batch[:, 1:] - batch[:, :-1] - # (N_batch, N_t-1, N_grid, d_features) - - diff_means.append( - torch.mean(batch_diffs, dim=(1, 2)) - ) # (N_batch', d_features,) - diff_squares.append( - torch.mean(batch_diffs**2, dim=(1, 2)) - ) # (N_batch', d_features,) - - diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0) # (d_features) - diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0) - diff_std = torch.sqrt(diff_second_moment - diff_mean**2) # (d_features) - - # Create xarray dataset - ds = xr.Dataset( - { - "mean": (["d_features"], mean), - "std": (["d_features"], std), - "diff_mean": (["d_features"], diff_mean), - "diff_std": (["d_features"], diff_std), - **fb_stats, - } - ) - - # Save dataset as Zarr - print("Saving dataset as Zarr...") - ds.to_zarr(args.zarr_path, mode="w") - - -if __name__ == "__main__": - main() diff --git a/docs/notebooks/create_reduced_meps_dataset.ipynb b/docs/notebooks/create_reduced_meps_dataset.ipynb new file mode 100644 index 00000000..daba23c4 --- /dev/null +++ b/docs/notebooks/create_reduced_meps_dataset.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating meps_example_reduced\n", + "This notebook outlines how the small-size test dataset ```meps_example_reduced``` was created based on the slightly larger dataset ```meps_example```. The zipped up datasets are 263 MB and 2.6 GB, respectively. See [README.md](../../README.md) for info on how to download ```meps_example```.\n", + "\n", + "The dataset was reduced in size by reducing the number of grid points and variables.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Standard library\n", + "import os\n", + "\n", + "# Third-party\n", + "import numpy as np\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of grid points was reduced to 1/4 by halving the number of coordinates in both the x and y direction. This was done by removing a quarter of the grid points along each outer edge, so the center grid points would stay centered in the new set.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load existing grid\n", + "grid_xy = np.load('data/meps_example/static/nwp_xy.npy')\n", + "# Get slices in each dimension by cutting off a quarter along each edge\n", + "num_x, num_y = grid_xy.shape[1:]\n", + "x_slice = slice(num_x//4, 3*num_x//4)\n", + "y_slice = slice(num_y//4, 3*num_y//4)\n", + "# Index and save reduced grid\n", + "grid_xy_reduced = grid_xy[:, x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/nwp_xy.npy', grid_xy_reduced)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "This cut out the border, so a new perimeter of 10 grid points was established as border (10 was also the border size in the original \"meps_example\").\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Outer 10 grid points are border\n", + "old_border_mask = np.load('data/meps_example/static/border_mask.npy')\n", + "assert np.all(old_border_mask[10:-10, 10:-10] == False)\n", + "assert np.all(old_border_mask[:10, :] == True)\n", + "assert np.all(old_border_mask[:, :10] == True)\n", + "assert np.all(old_border_mask[-10:,:] == True)\n", + "assert np.all(old_border_mask[:,-10:] == True)\n", + "\n", + "# Create new array with False everywhere but the outer 10 grid points\n", + "border_mask = np.zeros_like(grid_xy_reduced[0,:,:], dtype=bool)\n", + "border_mask[:10] = True\n", + "border_mask[:,:10] = True\n", + "border_mask[-10:] = True\n", + "border_mask[:,-10:] = True\n", + "np.save('data/meps_example_reduced/static/border_mask.npy', border_mask)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few other files also needed to be copied using only the new reduced grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load surface_geopotential.npy, index only values from the reduced grid, and save to new file\n", + "surface_geopotential = np.load('data/meps_example/static/surface_geopotential.npy')\n", + "surface_geopotential_reduced = surface_geopotential[x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/surface_geopotential.npy', surface_geopotential_reduced)\n", + "\n", + "# Load pytorch file grid_features.pt\n", + "grid_features = torch.load('data/meps_example/static/grid_features.pt')\n", + "# Index only values from the reduced grid. \n", + "# First reshape from (num_grid_points_total, 4) to (num_grid_points_x, num_grid_points_y, 4), \n", + "# then index, then reshape back to new total number of grid points\n", + "print(grid_features.shape)\n", + "grid_features_new = grid_features.reshape(num_x, num_y, 4)[x_slice,y_slice,:].reshape((-1, 4))\n", + "# Save to new file\n", + "torch.save(grid_features_new, 'data/meps_example_reduced/static/grid_features.pt')\n", + "\n", + "# flux_stats.pt is just a vector of length 2, so the grid shape and variable changes does not change this file\n", + "torch.save(torch.load('data/meps_example/static/flux_stats.pt'), 'data/meps_example_reduced/static/flux_stats.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of variables was reduced by truncating the variable list to the first 8." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_vars = 8\n", + "\n", + "# Load parameter_weights.npy, truncate to first 8 variables, and save to new file\n", + "parameter_weights = np.load('data/meps_example/static/parameter_weights.npy')\n", + "parameter_weights_reduced = parameter_weights[:num_vars]\n", + "np.save('data/meps_example_reduced/static/parameter_weights.npy', parameter_weights_reduced)\n", + "\n", + "# Do the same for following 4 pytorch files\n", + "for file in ['diff_mean', 'diff_std', 'parameter_mean', 'parameter_std']:\n", + " old_file = torch.load(f'data/meps_example/static/{file}.pt')\n", + " new_file = old_file[:num_vars]\n", + " torch.save(new_file, f'data/meps_example_reduced/static/{file}.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly the files in each of the directories train, test, and val have to be reduced. The folders all have the same structure with files of the following types:\n", + "```\n", + "nwp_YYYYMMDDHH_mbrXXX.npy\n", + "wtr_YYYYMMDDHH.npy\n", + "nwp_toa_downwelling_shortwave_flux_YYYYMMDDHH.npy\n", + "```\n", + "with ```YYYYMMDDHH``` being some date with hours, and ```XXX``` being some 3-digit integer.\n", + "\n", + "The first type of file has x and y in dimensions 1 and 2, and variable index in dimension 3. Dimension 0 is unchanged.\n", + "The second type has has x and y in dimensions 1 and 2. Dimension 0 is unchanged.\n", + "The last type has just x and y as the only 2 dimensions.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(65, 268, 238, 18)\n", + "(65, 268, 238)\n" + ] + } + ], + "source": [ + "print(np.load('data/meps_example/samples/train/nwp_2022040100_mbr000.npy').shape)\n", + "print(np.load('data/meps_example/samples/train/nwp_toa_downwelling_shortwave_flux_2022040112.npy').shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following loop goes through each file in each sample folder and indexes them according to the dimensions given by the file name." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for sample in ['train', 'test', 'val']:\n", + " files = os.listdir(f'data/meps_example/samples/{sample}')\n", + "\n", + " for f in files:\n", + " data = np.load(f'data/meps_example/samples/{sample}/{f}')\n", + " if 'mbr' in f:\n", + " data = data[:,x_slice,y_slice,:num_vars]\n", + " elif 'wtr' in f:\n", + " data = data[x_slice, y_slice]\n", + " else:\n", + " data = data[:,x_slice,y_slice]\n", + " np.save(f'data/meps_example_reduced/samples/{sample}/{f}', data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly, the file ```data_config.yaml``` is modified manually by truncating the variable units, long and short names, and setting the new grid shape. Also the unit descriptions containing ```^``` was automatically parsed using latex, and to avoid having to install latex in the GitHub CI/CD pipeline, this was changed to ```**```. \n", + "\n", + "This new config file was placed in ```data/meps_example_reduced```, and that directory was then zipped and placed in a European Weather Cloud S3 bucket." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/figures/component_dependencies.png b/figures/component_dependencies.png deleted file mode 100644 index fae77cab..00000000 Binary files a/figures/component_dependencies.png and /dev/null differ diff --git a/neural_lam/__init__.py b/neural_lam/__init__.py new file mode 100644 index 00000000..da4c4d2e --- /dev/null +++ b/neural_lam/__init__.py @@ -0,0 +1,9 @@ +# First-party +import neural_lam.interaction_net +import neural_lam.metrics +import neural_lam.models +import neural_lam.utils +import neural_lam.vis + +# Local +from .weather_dataset import WeatherDataset diff --git a/neural_lam/build_rectangular_graph.py b/neural_lam/build_rectangular_graph.py new file mode 100644 index 00000000..b22eaae8 --- /dev/null +++ b/neural_lam/build_rectangular_graph.py @@ -0,0 +1,332 @@ +# Standard library +import argparse +import os + +# Third-party +import cartopy.crs as ccrs +import numpy as np +import weather_model_graphs as wmg + +# Local +from . import utils +from .config import load_config_and_datastores + +WMG_ARCHETYPES = { + "keisler": wmg.create.archetype.create_keisler_graph, + "graphcast": wmg.create.archetype.create_graphcast_graph, + "hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph, +} + + +def main(input_args=None): + """ + Build rectangular graph from archetype, using cmd-line arguments. + """ + parser = argparse.ArgumentParser( + description="Rectangular graph generation using weather-models-graph", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Inputs and outputs + parser.add_argument( + "--config_path", + type=str, + help="Path to the configuration for neural-lam", + ) + parser.add_argument( + "--graph_name", + type=str, + help="Name to save graph as (default: multiscale)", + ) + parser.add_argument( + "--output_dir", + type=str, + default="graphs", + help="Directory to save graph to", + ) + + # Graph structure + parser.add_argument( + "--archetype", + type=str, + default="keisler", + help="Archetype to use to create graph " + "(keisler/graphcast/hierarchical)", + ) + parser.add_argument( + "--mesh_node_distance", + type=float, + default=3.0, + help="Distance between created mesh nodes", + ) + parser.add_argument( + "--level_refinement_factor", + type=int, + default=3, + help="Refinement factor between grid points and bottom level of " + "mesh hierarchy", + ) + parser.add_argument( + "--max_num_levels", + type=int, + help="Limit multi-scale mesh to given number of levels, " + "from bottom up", + ) + args = parser.parse_args(input_args) + + assert ( + args.config_path is not None + ), "Specify your config with --config_path" + assert ( + args.graph_name is not None + ), "Specify the name to save graph as with --graph_name" + + _, datastore, datastore_boundary = load_config_and_datastores( + config_path=args.config_path + ) + + create_kwargs = { + "mesh_node_distance": args.mesh_node_distance, + } + + if args.archetype != "keisler": + # Add additional multi-level kwargs + create_kwargs.update( + { + "level_refinement_factor": args.level_refinement_factor, + "max_num_levels": args.max_num_levels, + } + ) + + return build_graph_from_archetype( + datastore=datastore, + datastore_boundary=datastore_boundary, + graph_name=args.graph_name, + archetype=args.archetype, + **create_kwargs, + ) + + +def _build_wmg_graph( + datastore, + datastore_boundary, + graph_build_func, + kwargs, + graph_name, + dir_save_path=None, +): + """ + Build a graph using WMG in a way that's compatible with neural-lam. + Given datastores are used for coordinates and decode masking. + The given graph building function from WMG should be used, with kwargs. + + Parameters + ---------- + datastore : BaseDatastore + Datastore representing interior region of grid + datastore_boundary : BaseDatastore or None + Datastore representing boundary region, or None if no boundary forcing + graph_build_func + Function from WMG to use to build graph + kwargs : dict + Keyword arguments to feed to graph_build_func. Should not include + coords, coords_crs, graph_crs, return_components or decode_mask, as + these are here derived in a consistent way from the datastores. + graph_name : str + Name to save the graph as. + dir_save_path : str or None + Path to directory where graph should be saved, in directory graph_name. + If None, save in "graphs" directory in the root directory of datastore. + """ + + for derived_kwarg in ( + "coords", + "coords_crs", + "graph_crs", + "return_components", + "decode_mask", + ): + assert derived_kwarg not in kwargs, ( + f"Argument {derived_kwarg} should not be manually given when " + "building rectangular graph." + ) + + # Load grid positions + coords = utils.get_stacked_lat_lons(datastore, datastore_boundary) + # (num_nodes_full, 2) + # Project using crs from datastore for graph building + coords_crs = ccrs.PlateCarree() + graph_crs = datastore.coords_projection + + if datastore_boundary is None: + # No mask + decode_mask = None + else: + # Construct mask to decode only to interior + num_interior = datastore.num_grid_points + num_boundary = datastore_boundary.num_grid_points + decode_mask = np.concatenate( + ( + np.ones(num_interior, dtype=bool), + np.zeros(num_boundary, dtype=bool), + ), + axis=0, + ) + + # Set up all kwargs + create_kwargs = { + "coords": coords, + "decode_mask": decode_mask, + "graph_crs": graph_crs, + "coords_crs": coords_crs, + "return_components": True, + } + create_kwargs.update(kwargs) + + # Build graph + graph_comp = graph_build_func(**create_kwargs) + + print("Created graph:") + for name, subgraph in graph_comp.items(): + print(f"{name}: {subgraph}") + + # Need to know if hierarchical for saving + hierarchical = (graph_build_func == WMG_ARCHETYPES["hierarchical"]) or ( + "m2m_connectivity" in kwargs + and kwargs["m2m_connectivity"] == "hierarchical" + ) + + # Save graph + if dir_save_path is None: + graph_dir_path = os.path.join(datastore.root_path, "graphs", graph_name) + else: + graph_dir_path = os.path.join(dir_save_path, graph_name) + + os.makedirs(graph_dir_path, exist_ok=True) + for component, graph in graph_comp.items(): + # This seems like a bit of a hack, maybe better if saving in wmg + # was made consistent with nl + if component == "m2m": + if hierarchical: + # Split by direction + m2m_direction_comp = wmg.split_graph_by_edge_attribute( + graph, attr="direction" + ) + for direction, dir_graph in m2m_direction_comp.items(): + if direction == "same": + # Name just m2m to be consistent with non-hierarchical + wmg.save.to_pyg( + graph=dir_graph, + name="m2m", + list_from_attribute="level", + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + else: + # up and down directions + wmg.save.to_pyg( + graph=dir_graph, + name=f"mesh_{direction}", + list_from_attribute="levels", + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + list_from_attribute="dummy", # Note: Needed to output list + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + edge_features=["len", "vdiff"], + output_directory=graph_dir_path, + ) + + +def build_graph_from_archetype( + datastore, + datastore_boundary, + graph_name, + archetype, + dir_save_path=None, + **kwargs, +): + """ + Function that builds graph using wmg archetype. + Uses archetype functions from wmg.create.archetype with kwargs being passed + on directly to those functions. + + Parameters + ---------- + datastore : BaseDatastore + Datastore representing interior region of grid + datastore_boundary : BaseDatastore or None + Datastore representing boundary region, or None if no boundary forcing + graph_name : str + Name to save the graph as. + archetype : str + Archetype to build. Must be one of "keisler", "graphcast" + or "hierarchical" + dir_save_path : str or None + Path to directory where graph should be saved, in directory graph_name. + If None, save in "graphs" directory in the root directory of datastore. + **kwargs + Keyword arguments that are passed on to + wmg.create.base.create_all_graph_components. See WMG for accepted + values for these. + """ + + assert archetype in WMG_ARCHETYPES, f"Unknown archetype: {archetype}" + archetype_create_func = WMG_ARCHETYPES[archetype] + + return _build_wmg_graph( + datastore=datastore, + datastore_boundary=datastore_boundary, + graph_build_func=archetype_create_func, + graph_name=graph_name, + dir_save_path=dir_save_path, + kwargs=kwargs, + ) + + +def build_graph( + datastore, datastore_boundary, graph_name, dir_save_path=None, **kwargs +): + """ + Function that can be used for more fine-grained control of graph + construction. Directly uses wmg.create.base.create_all_graph_components, + with kwargs being passed on directly to there. + + Parameters + ---------- + datastore : BaseDatastore + Datastore representing interior region of grid + datastore_boundary : BaseDatastore or None + Datastore representing boundary region, or None if no boundary forcing + graph_name : str + Name to save the graph as. + dir_save_path : str or None + Path to directory where graph should be saved, in directory graph_name. + If None, save in "graphs" directory in the root directory of datastore. + **kwargs + Keyword arguments that are passed on to + wmg.create.base.create_all_graph_components. See WMG for accepted + values for these. + """ + return _build_wmg_graph( + datastore=datastore, + datastore_boundary=datastore_boundary, + graph_build_func=wmg.create.base.create_all_graph_components, + graph_name=graph_name, + dir_save_path=dir_save_path, + kwargs=kwargs, + ) + + +if __name__ == "__main__": + main() diff --git a/neural_lam/config.py b/neural_lam/config.py new file mode 100644 index 00000000..49440953 --- /dev/null +++ b/neural_lam/config.py @@ -0,0 +1,207 @@ +# Standard library +import dataclasses +from pathlib import Path +from typing import Dict, Union + +# Third-party +import dataclass_wizard + +# Local +from .datastore import ( + DATASTORES, + MDPDatastore, + NpyFilesDatastoreMEPS, + init_datastore, +) + + +class DatastoreKindStr(str): + VALID_KINDS = DATASTORES.keys() + + def __new__(cls, value): + if value not in cls.VALID_KINDS: + raise ValueError(f"Invalid datastore kind: {value}") + return super().__new__(cls, value) + + +@dataclasses.dataclass +class DatastoreSelection: + """ + Configuration for selecting a datastore to use with neural-lam. + + Attributes + ---------- + kind : DatastoreKindStr + The kind of datastore to use, currently `mdp` or `npyfilesmeps` are + implemented. + config_path : str + The path to the configuration file for the selected datastore, this is + assumed to be relative to the configuration file for neural-lam. + """ + + kind: DatastoreKindStr + config_path: str + + +@dataclasses.dataclass +class ManualStateFeatureWeighting: + """ + Configuration for weighting the state features in the loss function where + the weights are manually specified. + + Attributes + ---------- + weights : Dict[str, float] + Manual weights for the state features. + """ + + weights: Dict[str, float] + + +@dataclasses.dataclass +class UniformFeatureWeighting: + """ + Configuration for weighting the state features in the loss function where + all state features are weighted equally. + """ + + pass + + +@dataclasses.dataclass +class OutputClamping: + """ + Configuration for clamping the output of the model. + + Attributes + ---------- + lower : Dict[str, float] + The minimum value to clamp each output feature to. + upper : Dict[str, float] + The maximum value to clamp each output feature to. + """ + + lower: Dict[str, float] = dataclasses.field(default_factory=dict) + upper: Dict[str, float] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class TrainingConfig: + """ + Configuration related to training neural-lam + + Attributes + ---------- + state_feature_weighting : Union[ManualStateFeatureWeighting, + UnformFeatureWeighting] + The method to use for weighting the state features in the loss + function. Defaults to uniform weighting (`UnformFeatureWeighting`, i.e. + all features are weighted equally). + """ + + state_feature_weighting: Union[ + ManualStateFeatureWeighting, UniformFeatureWeighting + ] = dataclasses.field(default_factory=UniformFeatureWeighting) + + output_clamping: OutputClamping = dataclasses.field( + default_factory=OutputClamping + ) + + +@dataclasses.dataclass +class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard): + """ + Dataclass for Neural-LAM configuration. This class is used to load and + store the configuration for using Neural-LAM. + + Attributes + ---------- + datastore : DatastoreSelection + The configuration for the datastore to use. + datastore_boundary : Union[DatastoreSelection, None] + The configuration for the boundary datastore to use, if any. If None, + no boundary datastore is used. + training : TrainingConfig + The configuration for training the model. + """ + + datastore: DatastoreSelection + datastore_boundary: Union[DatastoreSelection, None] = None + training: TrainingConfig = dataclasses.field(default_factory=TrainingConfig) + + class _(dataclass_wizard.JSONWizard.Meta): + """ + Define the configuration class as a JSON wizard class. + + Together `tag_key` and `auto_assign_tags` enable that when a `Union` of + types are used for an attribute, the specific type to deserialize to + can be specified in the serialised data using the `tag_key` value. In + our case we call the tag key `__config_class__` to indicate to the + user that they should pick a dataclass describing configuration in + neural-lam. This Union-based selection allows us to support different + configuration attributes for different choices of methods for example + and is used when picking between different feature weighting methods in + the `TrainingConfig` class. `auto_assign_tags` is set to True to + automatically set that tag key (i.e. `__config_class__` in the config + file) should just be the class name of the dataclass to deserialize to. + """ + + tag_key = "__config_class__" + auto_assign_tags = True + # ensure that all parts of the loaded configuration match the + # dataclasses used + # TODO: this should be enabled once + # https://github.com/rnag/dataclass-wizard/issues/137 is fixed, but + # currently cannot be used together with `auto_assign_tags` due to a + # bug it seems + # raise_on_unknown_json_key = True + + +class InvalidConfigError(Exception): + pass + + +def load_config_and_datastores( + config_path: str, +) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]: + """ + Load the neural-lam configuration and the datastores specified in the + configuration. + + Parameters + ---------- + config_path : str + Path to the Neural-LAM configuration file. + + Returns + ------- + tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]] + The Neural-LAM configuration and the loaded datastores. + """ + try: + config = NeuralLAMConfig.from_yaml_file(config_path) + except dataclass_wizard.errors.UnknownJSONKey as ex: + raise InvalidConfigError( + "There was an error loading the configuration file at " + f"{config_path}. " + ) from ex + # datastore config is assumed to be relative to the config file + datastore_config_path = ( + Path(config_path).parent / config.datastore.config_path + ) + datastore = init_datastore( + datastore_kind=config.datastore.kind, config_path=datastore_config_path + ) + + if config.datastore_boundary is not None: + datastore_boundary_config_path = ( + Path(config_path).parent / config.datastore_boundary.config_path + ) + datastore_boundary = init_datastore( + datastore_kind=config.datastore_boundary.kind, + config_path=datastore_boundary_config_path, + ) + else: + datastore_boundary = None + + return config, datastore, datastore_boundary diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py new file mode 100644 index 00000000..40e683ac --- /dev/null +++ b/neural_lam/datastore/__init__.py @@ -0,0 +1,26 @@ +# Local +from .base import BaseDatastore # noqa +from .mdp import MDPDatastore # noqa +from .npyfilesmeps import NpyFilesDatastoreMEPS # noqa + +DATASTORE_CLASSES = [ + MDPDatastore, + NpyFilesDatastoreMEPS, +] + +DATASTORES = { + datastore.SHORT_NAME: datastore for datastore in DATASTORE_CLASSES +} + + +def init_datastore(datastore_kind, config_path): + DatastoreClass = DATASTORES.get(datastore_kind) + + if DatastoreClass is None: + raise NotImplementedError( + f"Datastore kind {datastore_kind} is not implemented" + ) + + datastore = DatastoreClass(config_path=config_path) + + return datastore diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py new file mode 100644 index 00000000..701dba83 --- /dev/null +++ b/neural_lam/datastore/base.py @@ -0,0 +1,617 @@ +# Standard library +import abc +import collections +import dataclasses +import functools +from functools import cached_property +from pathlib import Path +from typing import List, Union + +# Third-party +import cartopy.crs as ccrs +import numpy as np +import xarray as xr +from pandas.core.indexes.multi import MultiIndex + + +class BaseDatastore(abc.ABC): + """ + Base class for weather data used in the neural-lam package. A datastore + defines the interface for accessing weather data by providing methods to + access the data in a processed format that can be used for training and + evaluation of neural networks. + + NOTE: All methods return either primitive types, `numpy.ndarray`, + `xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor` + objects. Conversion to `pytorch.Tensor` objects should be done in the + `weather_dataset.WeatherDataset` class (which inherits from + `torch.utils.data.Dataset` and uses the datastore to access the data). + + # Forecast vs analysis data + If the datastore is used to represent forecast rather than analysis data, + then the `is_forecast` attribute should be set to True, and returned data + from `get_dataarray` is assumed to have `analysis_time` and `forecast_time` + dimensions (rather than just `time`). + + # Ensemble vs deterministic data + If the datastore is used to represent ensemble data, then the `is_ensemble` + attribute should be set to True, and returned data from `get_dataarray` is + assumed to have an `ensemble_member` dimension. + + # Grid index + All methods that return data specific to a grid point (like + `get_dataarray`) should have a single dimension named `grid_index` that + represents the spatial grid index of the data. The actual x, y coordinates + of the grid points should be stored in the `x` and `y` coordinates of the + dataarray or dataset with the `grid_index` dimension as the coordinate for + each of the `x` and `y` coordinates. + """ + + is_forecast: bool = False + + @property + @abc.abstractmethod + def root_path(self) -> Path: + """ + The root path to the datastore. It is relative to this that any derived + files (for example the graph components) are stored. + + Returns + ------- + pathlib.Path + The root path to the datastore. + + """ + pass + + @property + @abc.abstractmethod + def config(self) -> collections.abc.Mapping: + """The configuration of the datastore. + + Returns + ------- + collections.abc.Mapping + The configuration of the datastore, any dict like object can be + returned. + + """ + pass + + @property + @abc.abstractmethod + def step_length(self) -> int: + """The step length of the dataset in hours. + + Returns: + int: The step length in hours. + + """ + pass + + @abc.abstractmethod + def get_vars_units(self, category: str) -> List[str]: + """Get the units of the variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + List[str] + The units of the variables. + + """ + pass + + @abc.abstractmethod + def get_vars_names(self, category: str) -> List[str]: + """Get the names of the variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + List[str] + The names of the variables. + + """ + pass + + @abc.abstractmethod + def get_vars_long_names(self, category: str) -> List[str]: + """Get the long names of the variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + List[str] + The long names of the variables. + + """ + pass + + @abc.abstractmethod + def get_num_data_vars(self, category: str) -> int: + """Get the number of data variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + int + The number of data variables. + + """ + pass + + @abc.abstractmethod + def get_standardization_dataarray(self, category: str) -> xr.Dataset: + """ + Return the standardization (i.e. scaling to mean of 0.0 and standard + deviation of 1.0) dataarray for the given category. This should contain + a `{category}_mean` and `{category}_std` variable for each variable in + the category. For `category=="state"`, the dataarray should also + contain a `state_diff_mean` and `state_diff_std` variable for the one- + step differences of the state variables. The returned dataarray should + at least have dimensions of `({category}_feature)`, but can also + include for example `grid_index` (if the standardization is done per + grid point for example). + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + xr.Dataset + The standardization dataarray for the given category, with variables + for the mean and standard deviation of the variables (and + differences for state variables). + + """ + pass + + def _standardize_datarray( + self, da: xr.DataArray, category: str + ) -> xr.DataArray: + """ + Helper function to standardize a dataarray before returning it. + + Parameters + ---------- + da: xr.DataArray + The dataarray to standardize + category : str + The category of the dataarray (state/forcing/static), to load + standardization statistics for. + + Returns + ------- + xr.Dataarray + The standardized dataarray + """ + + standard_da = self.get_standardization_dataarray(category=category) + + mean = standard_da[f"{category}_mean"] + std = standard_da[f"{category}_std"] + + return (da - mean) / std + + @abc.abstractmethod + def get_dataarray( + self, category: str, split: str, standardize: bool = False + ) -> Union[xr.DataArray, None]: + """ + Return the processed data (as a single `xr.DataArray`) for the given + category of data and test/train/val-split that covers all the data (in + space and time) of a given category (state/forcing/static). For the + "static" category the `split` is allowed to be `None` because the static + data is the same for all splits. + + The returned dataarray is expected to at minimum have dimensions of + `(grid_index, {category}_feature)` so that any spatial dimensions have + been stacked into a single dimension and all variables and levels have + been stacked into a single feature dimension named by the `category` of + data being loaded. + + For categories of data that have a time dimension (i.e. not static + data), the dataarray is expected additionally have `(analysis_time, + elapsed_forecast_duration)` dimensions if `is_forecast` is True, or + `(time)` if `is_forecast` is False. + + If the data is ensemble data, the dataarray is expected to have an + additional `ensemble_member` dimension. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + split : str + The time split to filter the dataset (train/val/test). + standardize: bool + If the dataarray should be returned standardized + + Returns + ------- + xr.DataArray or None + The xarray DataArray object with processed dataset. + + """ + pass + + @abc.abstractmethod + def get_xy(self, category: str) -> np.ndarray: + """ + Return the x, y coordinates of the dataset as a numpy arrays for a + given category of data. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + np.ndarray + The x, y coordinates of the dataset with shape `[n_grid_points, 2]`. + """ + pass + + @property + @abc.abstractmethod + def coords_projection(self) -> ccrs.Projection: + """Return the projection object for the coordinates. + + The projection object is used to plot the coordinates on a map. + + Returns + ------- + cartopy.crs.Projection: + The projection object. + + """ + pass + + @functools.lru_cache + def get_lat_lon(self, category: str) -> np.ndarray: + """ + Return the longitude, latitude coordinates of the dataset as numpy + array for a given category of data. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + np.ndarray + The longitude, latitude coordinates of the dataset + with shape `[n_grid_points, 2]`. + """ + xy = self.get_xy(category=category) + + transformed_points = ccrs.PlateCarree().transform_points( + self.coords_projection, xy[:, 0], xy[:, 1] + ) + return transformed_points[:, :2] # Remove z-dim + + @functools.lru_cache + def get_xy_extent(self, category: str) -> List[float]: + """ + Return the extent of the x, y coordinates for a given category of data. + The extent should be returned as a list of 4 floats with `[xmin, xmax, + ymin, ymax]` which can then be used to set the extent of a plot. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + List[float] + The extent of the x, y coordinates. + + """ + xy = self.get_xy(category, stacked=True) + extent = [ + xy[:, 0].min(), + xy[:, 0].max(), + xy[:, 1].min(), + xy[:, 1].max(), + ] + return [float(v) for v in extent] + + @property + @abc.abstractmethod + def num_grid_points(self) -> int: + """Return the number of grid points in the dataset. + + Returns + ------- + int + The number of grid points in the dataset. + + """ + pass + + @property + def num_ensemble_members(self) -> int: + """Return the number of ensemble members in the dataset. + + Returns + ------- + int + The number of ensemble members in the dataset (default is 1 - + not an ensemble). + + """ + return 1 + + @property + def is_ensemble(self) -> bool: + """Return whether the dataset represents ensemble data. + + Returns + ------- + bool + True if the dataset represents ensemble data, False otherwise. + + """ + return self.num_ensemble_members > 1 + + @cached_property + @abc.abstractmethod + def state_feature_weights_values(self) -> List[float]: + """ + Return the weights for each state feature as a list of floats. The + weights are defined by the user in a config file for the datastore. + + Implementations of this method must assert that there is one weight for + each state feature in the datastore. The weights can be used to scale + the loss function for each state variable (e.g. via the standard + deviation of the 1-step differences of the state variables). + + Returns: + List[float]: The weights for each state feature. + """ + pass + + @functools.lru_cache + def expected_dim_order(self, category: str = None) -> tuple[str]: + """ + Return the expected dimension order for the dataarray or dataset + returned by `get_dataarray` for the given category of data. The + dimension order is the order of the dimensions in the dataarray or + dataset, and is used to check that the data is in the expected format. + + This is necessary so that when stacking and unstacking the spatial grid + we can ensure that the dimension order is the same as what is returned + from `get_dataarray`. And also ensures that downstream uses of a + datastore (e.g. WeatherDataset) sees the data in a common structure. + + If the category is None, then the it assumed that data only represents + a 1D scalar field varying with grid-index. + + The order is constructed to match the order in `pytorch.Tensor` objects + that will be constructed from the data so that the last two dimensions + are always the grid-index and feature dimensions (i.e. the order is + `[..., grid_index, {category}_feature]`), with any time-related and + ensemble-number dimension(s) coming before these two. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + List[str] + The expected dimension order for the dataarray or dataset. + + """ + dim_order = [] + + if category is not None: + if category != "static": + # static data does not vary in time + if self.is_forecast: + dim_order.extend( + ["analysis_time", "elapsed_forecast_duration"] + ) + elif not self.is_forecast: + dim_order.append("time") + + if self.is_ensemble and category == "state": + # XXX: for now we only assume ensemble data for state variables + dim_order.append("ensemble_member") + + dim_order.append("grid_index") + + if category is not None: + dim_order.append(f"{category}_feature") + + return tuple(dim_order) + + +@dataclasses.dataclass +class CartesianGridShape: + """Dataclass to store the shape of a grid.""" + + x: int + y: int + + +class BaseRegularGridDatastore(BaseDatastore): + """ + Base class for weather data stored on a regular grid (like a chess-board, + as opposed to a irregular grid where each cell cannot be indexed by just + two integers, see https://en.wikipedia.org/wiki/Regular_grid). In addition + to the methods and attributes required for weather data in general (see + `BaseDatastore`) for regular-gridded source data each `grid_index` + coordinate value is assumed to be associated with `x` and `y`-values that + allow the processed data-arrays can be reshaped back into into 2D + xy-gridded arrays. + + The following methods and attributes must be implemented for datastore that + represents regular-gridded data: + - `grid_shape_state` (property): 2D shape of the grid for the state + variables. + - `get_xy` (method): Return the x, y coordinates of the dataset, with the + option to not stack the coordinates (so that they are returned as a 2D + grid). + + The operation of going from (x,y)-indexed regular grid + to `grid_index`-indexed data-array is called "stacking" and the reverse + operation is called "unstacking". This class provides methods to stack and + unstack the spatial grid coordinates of the data-arrays (called + `stack_grid_coords` and `unstack_grid_coords` respectively). + """ + + CARTESIAN_COORDS = ["x", "y"] + + @cached_property + @abc.abstractmethod + def grid_shape_state(self) -> CartesianGridShape: + """The shape of the grid for the state variables. + + Returns + ------- + CartesianGridShape: + The shape of the grid for the state variables, which has `x` and + `y` attributes. + + """ + pass + + @abc.abstractmethod + def get_xy(self, category: str, stacked: bool = True) -> np.ndarray: + """Return the x, y coordinates of the dataset. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + stacked : bool + Whether to stack the x, y coordinates. The parameter `stacked` has + been introduced in this class. Parent class `BaseDatastore` has the + same method signature but without the `stacked` parameter. Defaults + to `True` to match the behaviour of `BaseDatastore.get_xy()` which + always returns the coordinates stacked. + + Returns + ------- + np.ndarray + The x, y coordinates of the dataset, returned differently based on + the value of `stacked`: - `stacked==True`: shape `(n_grid_points, + 2)` where + n_grid_points=N_x*N_y. + - `stacked==False`: shape `(N_x, N_y, 2)` + """ + pass + + def unstack_grid_coords( + self, da_or_ds: Union[xr.DataArray, xr.Dataset] + ) -> Union[xr.DataArray, xr.Dataset]: + """ + Unstack the spatial grid coordinates from `grid_index` into separate `x` + and `y` dimensions to create a 2D grid. Only performs unstacking if the + data is currently stacked (has grid_index dimension). + + Parameters + ---------- + da_or_ds : xr.DataArray or xr.Dataset + The dataarray or dataset to unstack the grid coordinates of. + + Returns + ------- + xr.DataArray or xr.Dataset + The dataarray or dataset with the grid coordinates unstacked. + """ + # Return original data if already unstacked (no grid_index dimension) + if "grid_index" not in da_or_ds.dims: + return da_or_ds + + # Check whether `grid_index` is a multi-index + if not isinstance(da_or_ds.indexes.get("grid_index"), MultiIndex): + da_or_ds = da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS) + + da_or_ds_unstacked = da_or_ds.unstack("grid_index") + + # Ensure that the x, y dimensions are in the correct order + dims = da_or_ds_unstacked.dims + xy_dim_order = [d for d in dims if d in self.CARTESIAN_COORDS] + + if xy_dim_order != self.CARTESIAN_COORDS: + da_or_ds_unstacked = da_or_ds_unstacked.transpose("x", "y") + + return da_or_ds_unstacked + + def stack_grid_coords( + self, da_or_ds: Union[xr.DataArray, xr.Dataset] + ) -> Union[xr.DataArray, xr.Dataset]: + """ + Stack the spatial grid coordinates (x and y) into a single `grid_index` + dimension. Only performs stacking if the data is currently unstacked + (has x and y dimensions). + + Parameters + ---------- + da_or_ds : xr.DataArray or xr.Dataset + The dataarray or dataset to stack the grid coordinates of. + + Returns + ------- + xr.DataArray or xr.Dataset + The dataarray or dataset with the grid coordinates stacked. + """ + # Return original data if already stacked (has grid_index dimension) + if "grid_index" in da_or_ds.dims: + return da_or_ds + + da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) + + # infer what category of data the array represents by finding the + # dimension named in the format `{category}_feature` + category = None + for dim in da_or_ds_stacked.dims: + if dim.endswith("_feature"): + if category is not None: + raise ValueError( + "Multiple dimensions ending with '_feature' found in " + f"dataarray: {da_or_ds_stacked}. Cannot infer category." + ) + category = dim.split("_")[0] + + dim_order = self.expected_dim_order(category=category) + + return da_or_ds_stacked.transpose(*dim_order) + + @property + @functools.lru_cache + def num_grid_points(self) -> int: + """Return the number of grid points in the dataset. + + Returns + ------- + int + The number of grid points in the dataset. + + """ + return self.grid_shape_state.x * self.grid_shape_state.y diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py new file mode 100644 index 00000000..9582c558 --- /dev/null +++ b/neural_lam/datastore/mdp.py @@ -0,0 +1,486 @@ +# Standard library +import copy +import functools +import warnings +from functools import cached_property +from pathlib import Path +from typing import List, Union + +# Third-party +import cartopy.crs as ccrs +import mllam_data_prep as mdp +import numpy as np +import xarray as xr +from loguru import logger + +# Local +from .base import BaseRegularGridDatastore, CartesianGridShape + + +class MDPDatastore(BaseRegularGridDatastore): + """ + Datastore class for datasets made with the mllam_data_prep library + (https://github.com/mllam/mllam-data-prep). This class wraps the + `mllam_data_prep` library to do the necessary transforms to create the + different categories (state/forcing/static) of data, with the actual + transform to do being specified in the configuration file. + """ + + SHORT_NAME = "mdp" + + def __init__(self, config_path, reuse_existing=True): + """ + Construct a new MDPDatastore from the configuration file at + `config_path`. If `reuse_existing` is True, the dataset is loaded + from a zarr file if it exists (unless the config has been modified + since the zarr was created), otherwise it is created from the + configuration file. + + Parameters + ---------- + config_path : str + The path to the configuration file, this will be fed to the + `mllam_data_prep.Config.from_yaml_file` method to then call + `mllam_data_prep.create_dataset` to create the dataset. + reuse_existing : bool + Whether to reuse an existing dataset zarr file if it exists and its + creation date is newer than the configuration file. + + """ + self._config_path = Path(config_path) + self._root_path = self._config_path.parent + self._config = mdp.Config.from_yaml_file(self._config_path) + fp_ds = self._root_path / self._config_path.name.replace( + ".yaml", ".zarr" + ) + + self._ds = None + if reuse_existing and fp_ds.exists(): + # check that the zarr directory is newer than the config file + if fp_ds.stat().st_mtime < self._config_path.stat().st_mtime: + logger.warning( + "Config file has been modified since zarr was created. " + f"The old zarr archive (in {fp_ds}) will be used." + "To generate new zarr-archive, move the old one first." + ) + self._ds = xr.open_zarr(fp_ds, consolidated=True) + + if self._ds is None: + self._ds = mdp.create_dataset(config=self._config) + self._ds.to_zarr(fp_ds) + + print("The loaded datastore contains the following features:") + for category in ["state", "forcing", "static"]: + if len(self.get_vars_names(category)) > 0: + var_names = self.get_vars_names(category) + print(f" {category:<8s}: {' '.join(var_names)}") + + # check that all three train/val/test splits are available + required_splits = ["train", "val", "test"] + available_splits = list(self._ds.splits.split_name.values) + if not all(split in available_splits for split in required_splits): + raise ValueError( + f"Missing required splits: {required_splits} in available " + f"splits: {available_splits}" + ) + + print("With the following splits (over time):") + for split in required_splits: + da_split = self._ds.splits.sel(split_name=split) + if "grid_index" in da_split.coords: + da_split = da_split.isel(grid_index=0) + da_split_start = da_split.sel(split_part="start").load().item() + da_split_end = da_split.sel(split_part="end").load().item() + print(f" {split:<8s}: {da_split_start} to {da_split_end}") + + # find out the dimension order for the stacking to grid-index + dim_order = None + for input_dataset in self._config.inputs.values(): + dim_order_ = input_dataset.dim_mapping["grid_index"].dims + if dim_order is None: + dim_order = dim_order_ + else: + assert ( + dim_order == dim_order_ + ), "all inputs must have the same dimension order" + + self.CARTESIAN_COORDS = dim_order + + @property + def root_path(self) -> Path: + """The root path of the dataset. + + Returns + ------- + Path + The root path of the dataset. + + """ + return self._root_path + + @property + def config(self) -> mdp.Config: + """The configuration of the dataset. + + Returns + ------- + mdp.Config + The configuration of the dataset. + + """ + return self._config + + @property + def step_length(self) -> int: + """The length of the time steps in hours. + + Returns + ------- + int + The length of the time steps in hours. + + """ + da_dt = self._ds["time"].diff("time") + return (da_dt.dt.seconds[0] // 3600).item() + + def get_vars_units(self, category: str) -> List[str]: + """Return the units of the variables in the given category. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + List[str] + The units of the variables in the given category. + + """ + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] + return self._ds[f"{category}_feature_units"].values.tolist() + + def get_vars_names(self, category: str) -> List[str]: + """Return the names of the variables in the given category. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + List[str] + The names of the variables in the given category. + + """ + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] + return self._ds[f"{category}_feature"].values.tolist() + + def get_vars_long_names(self, category: str) -> List[str]: + """ + Return the long names of the variables in the given category. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + List[str] + The long names of the variables in the given category. + + """ + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] + return self._ds[f"{category}_feature_long_name"].values.tolist() + + def get_num_data_vars(self, category: str) -> int: + """Return the number of variables in the given category. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + int + The number of variables in the given category. + + """ + return len(self.get_vars_names(category)) + + def get_dataarray( + self, category: str, split: str, standardize: bool = False + ) -> Union[xr.DataArray, None]: + """ + Return the processed data (as a single `xr.DataArray`) for the given + category of data and test/train/val-split that covers all the data (in + space and time) of a given category (state/forcing/static). The method + will return `None` if the category is not found in the datastore. + + The returned dataarray will at minimum have dimensions of `(grid_index, + {category}_feature)` so that any spatial dimensions have been stacked + into a single dimension and all variables and levels have been stacked + into a single feature dimension named by the `category` of data being + loaded. + + For categories of data that have a time dimension (i.e. not static + data), the dataarray will additionally have `(analysis_time, + elapsed_forecast_duration)` dimensions if `is_forecast` is True, or + `(time)` if `is_forecast` is False. + + If the data is ensemble data, the dataarray will have an additional + `ensemble_member` dimension. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + split : str + The time split to filter the dataset (train/val/test). + standardize: bool + If the dataarray should be returned standardized + + Returns + ------- + xr.DataArray or None + The xarray DataArray object with processed dataset. + + """ + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return None + + da_category = self._ds[category] + + # set multi-index for grid-index + da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS) + + if "time" in da_category.dims: + da_split = self._ds.splits.sel(split_name=split) + if "grid_index" in da_split.coords: + da_split = da_split.isel(grid_index=0) + t_start = da_split.sel(split_part="start").load().item() + t_end = da_split.sel(split_part="end").load().item() + da_category = da_category.sel(time=slice(t_start, t_end)) + + dim_order = self.expected_dim_order(category=category) + da_category = da_category.transpose(*dim_order) + + if standardize: + return self._standardize_datarray(da_category, category=category) + + return da_category + + def get_standardization_dataarray(self, category: str) -> xr.Dataset: + """ + Return the standardization dataarray for the given category. This + should contain a `{category}_mean` and `{category}_std` variable for + each variable in the category. For `category=="state"`, the dataarray + should also contain a `state_diff_mean` and `state_diff_std` variable + for the one- step differences of the state variables. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + xr.Dataset + The standardization dataarray for the given category, with + variables for the mean and standard deviation of the variables (and + differences for state variables). + + """ + ops = ["mean", "std"] + split = "train" + stats_variables = { + f"{category}__{split}__{op}": f"{category}_{op}" for op in ops + } + if category == "state": + stats_variables.update( + {f"state__{split}__diff_{op}": f"state_diff_{op}" for op in ops} + ) + + ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) + if "grid_index" in ds_stats.coords: + ds_stats = ds_stats.isel(grid_index=0) + return ds_stats + + @property + def coords_projection(self) -> ccrs.Projection: + """ + Return the projection of the coordinates. + + NOTE: currently this expects the projection information to be in the + `extra` section of the configuration file, with a `projection` key + containing a `class_name` and `kwargs` for constructing the + `cartopy.crs.Projection` object. This is a temporary solution until + the projection information can be parsed in the produced dataset + itself. `mllam-data-prep` ignores the contents of the `extra` section + of the config file which is why we need to check that the necessary + parts are there. + + Returns + ------- + ccrs.Projection + The projection of the coordinates. + + """ + if "projection" not in self._config.extra: + raise ValueError( + "projection information not found in the configuration file " + f"({self._config_path}). Please add the projection information" + "to the `extra` section of the config, by adding a " + "`projection` key with the class name and kwargs of the " + "projection." + ) + + projection_info = self._config.extra["projection"] + if "class_name" not in projection_info: + raise ValueError( + "class_name not found in the projection information. Please " + "add the class name of the projection to the `projection` key " + "in the `extra` section of the config." + ) + if "kwargs" not in projection_info: + raise ValueError( + "kwargs not found in the projection information. Please add " + "the keyword arguments of the projection to the `projection` " + "key in the `extra` section of the config." + ) + + class_name = projection_info["class_name"] + ProjectionClass = getattr(ccrs, class_name) + # need to copy otherwise we modify the dict stored in the dataclass + # in-place + kwargs = copy.deepcopy(projection_info["kwargs"]) + + globe_kwargs = kwargs.pop("globe", {}) + if len(globe_kwargs) > 0: + kwargs["globe"] = ccrs.Globe(**globe_kwargs) + + return ProjectionClass(**kwargs) + + @cached_property + def grid_shape_state(self): + """The shape of the cartesian grid for the state variables. + + Returns + ------- + CartesianGridShape + The shape of the cartesian grid for the state variables. + + """ + # Boundary data often has no state features + if "state" not in self._ds: + warnings.warn( + "no state data found in datastore" + "returning grid shape from forcing data" + ) + da_grid_reference = self.unstack_grid_coords(self._ds["forcing"]) + else: + da_grid_reference = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = da_grid_reference.x, da_grid_reference.y + assert da_x.ndim == da_y.ndim == 1 + return CartesianGridShape(x=da_x.size, y=da_y.size) + + def get_xy(self, category: str, stacked: bool = True) -> np.ndarray: + """Return the x, y coordinates of the dataset. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + stacked : bool + Whether to stack the x, y coordinates. + + Returns + ------- + np.ndarray + The x, y coordinates of the dataset, returned differently based on + the value of `stacked`: + - `stacked==True`: shape `(n_grid_points, 2)` where + n_grid_points=N_x*N_y. + - `stacked==False`: shape `(N_x, N_y, 2)` + + """ + # assume variables are stored in dimensions [grid_index, ...] + ds_category = self.unstack_grid_coords(da_or_ds=self._ds[category]) + + da_xs = ds_category.x + da_ys = ds_category.y + + assert da_xs.ndim == da_ys.ndim == 1, "x and y coordinates must be 1D" + + da_x, da_y = xr.broadcast(da_xs, da_ys) + da_xy = xr.concat([da_x, da_y], dim="grid_coord") + + if stacked: + da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose( + "grid_index", + "grid_coord", + ) + else: + dims = [ + "x", + "y", + "grid_coord", + ] + da_xy = da_xy.transpose(*dims) + + return da_xy.values + + @functools.lru_cache + def get_lat_lon(self, category: str) -> np.ndarray: + """ + Return the longitude, latitude coordinates of the dataset as numpy + array for a given category of data. + Override in MDP to use lat/lons directly from xr.Dataset, if available. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + np.ndarray + The longitude, latitude coordinates of the dataset + with shape `[n_grid_points, 2]`. + """ + # Check first if lat/lon saved in ds + lookup_ds = self._ds + if "latitude" in lookup_ds.coords and "longitude" in lookup_ds.coords: + lon = lookup_ds.longitude + lat = lookup_ds.latitude + elif "lat" in lookup_ds.coords and "lon" in lookup_ds.coords: + lon = lookup_ds.lon + lat = lookup_ds.lat + else: + # Not saved, use method from BaseDatastore to derive from x/y + return super().get_lat_lon(category) + + coords = np.stack((lon.values, lat.values), axis=1) + return coords + + @property + def num_grid_points(self) -> int: + """Return the number of grid points in the dataset. + + Returns + ------- + int + The number of grid points in the dataset. + + """ + return len(self._ds.grid_index) diff --git a/neural_lam/datastore/npyfilesmeps/__init__.py b/neural_lam/datastore/npyfilesmeps/__init__.py new file mode 100644 index 00000000..397a5075 --- /dev/null +++ b/neural_lam/datastore/npyfilesmeps/__init__.py @@ -0,0 +1,2 @@ +# Local +from .store import NpyFilesDatastoreMEPS # noqa diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py new file mode 100644 index 00000000..1f1c6943 --- /dev/null +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -0,0 +1,417 @@ +# Standard library +import os +import subprocess +from argparse import ArgumentParser +from pathlib import Path + +# Third-party +import torch +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +# First-party +from neural_lam import WeatherDataset +from neural_lam.datastore import init_datastore + + +class PaddedWeatherDataset(torch.utils.data.Dataset): + def __init__(self, base_dataset, world_size, batch_size): + super().__init__() + self.base_dataset = base_dataset + self.world_size = world_size + self.batch_size = batch_size + self.total_samples = len(base_dataset) + self.padded_samples = ( + (self.world_size * self.batch_size) - self.total_samples + ) % self.world_size + self.original_indices = list(range(len(base_dataset))) + self.padded_indices = list( + range(self.total_samples, self.total_samples + self.padded_samples) + ) + + def __getitem__(self, idx): + return self.base_dataset[ + self.original_indices[-1] + if idx >= self.total_samples + else idx % len(self.base_dataset) + ] + + def __len__(self): + return self.total_samples + self.padded_samples + + def get_original_indices(self): + return self.original_indices + + def get_original_window_indices(self, step_length): + return [ + i // step_length + for i in range(len(self.original_indices) * step_length) + ] + + +def get_rank(): + return int(os.environ.get("SLURM_PROCID", 0)) + + +def get_world_size(): + return int(os.environ.get("SLURM_NTASKS", 1)) + + +def setup(rank, world_size): # pylint: disable=redefined-outer-name + """Initialize the distributed group.""" + if "SLURM_JOB_NODELIST" in os.environ: + master_node = ( + subprocess.check_output( + "scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1", + shell=True, + ) + .strip() + .decode("utf-8") + ) + else: + print( + "\033[91mCareful, you are running this script with --distributed " + "without any scheduler. In most cases this will result in slower " + "execution and the --distributed flag should be removed.\033[0m" + ) + master_node = "localhost" + os.environ["MASTER_ADDR"] = master_node + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group( + "nccl" if torch.cuda.is_available() else "gloo", + rank=rank, + world_size=world_size, + ) + if rank == 0: + print( + f"Initialized {dist.get_backend()} " + f"process group with world size {world_size}." + ) + + +def save_stats( + static_dir_path, means, squares, flux_means, flux_squares, filename_prefix +): + means = ( + torch.stack(means) if len(means) > 1 else means[0] + ) # (N_batch, d_features,) + squares = ( + torch.stack(squares) if len(squares) > 1 else squares[0] + ) # (N_batch, d_features,) + mean = torch.mean(means, dim=0) # (d_features,) + second_moment = torch.mean(squares, dim=0) # (d_features,) + std = torch.sqrt(second_moment - mean**2) # (d_features,) + print( + f"Saving {filename_prefix} mean and std.-dev. to " + f"{filename_prefix}_mean.pt and {filename_prefix}_std.pt" + ) + torch.save( + mean.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_mean.pt") + ) + torch.save( + std.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_std.pt") + ) + + if len(flux_means) == 0: + return + flux_means = ( + torch.stack(flux_means) if len(flux_means) > 1 else flux_means[0] + ) # (N_batch,) + flux_squares = ( + torch.stack(flux_squares) if len(flux_squares) > 1 else flux_squares[0] + ) # (N_batch,) + flux_mean = torch.mean(flux_means) # (,) + flux_second_moment = torch.mean(flux_squares) # (,) + flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) + print("Saving flux mean and std.-dev. to flux_stats.pt") + torch.save( + torch.stack((flux_mean, flux_std)).cpu(), + os.path.join(static_dir_path, "flux_stats.pt"), + ) + + +def main( + datastore_config_path, batch_size, step_length, n_workers, distributed +): + """ + Pre-compute parameter weights to be used in loss function + + Arguments + --------- + datastore_config_path : str + Path to datastore config file + batch_size : int + Batch size when iterating over the dataset + step_length : int + Step length in hours to consider single time step + n_workers : int + Number of workers in data loader + distributed : bool + Run the script in distributed + """ + + rank = get_rank() + world_size = get_world_size() + datastore = init_datastore( + datastore_kind="npyfilesmeps", config_path=datastore_config_path + ) + + static_dir_path = Path(datastore_config_path).parent / "static" + os.makedirs(static_dir_path, exist_ok=True) + + if distributed: + setup(rank, world_size) + device = torch.device( + f"cuda:{rank}" if torch.cuda.is_available() else "cpu" + ) + torch.cuda.set_device(device) if torch.cuda.is_available() else None + + # Setting this to the original value of the Oskarsson et al. paper (2023) + # 65 forecast steps - 2 initial steps = 63 + ar_steps = 63 + ds = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + split="train", + ar_steps=ar_steps, + standardize=False, + num_past_forcing_steps=0, + num_future_forcing_steps=0, + ) + if distributed: + ds = PaddedWeatherDataset( + ds, + world_size, + batch_size, + ) + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=False + ) + else: + sampler = None + loader = torch.utils.data.DataLoader( + ds, + batch_size, + shuffle=False, + num_workers=n_workers, + sampler=sampler, + ) + + if rank == 0: + print("Computing mean and std.-dev. for parameters...") + means, squares, flux_means, flux_squares = [], [], [], [] + + for init_batch, target_batch, forcing_batch, _, _ in tqdm(loader): + if distributed: + init_batch, target_batch, forcing_batch = ( + init_batch.to(device), + target_batch.to(device), + forcing_batch.to(device), + ) + # (N_batch, N_t, N_grid, d_features) + batch = torch.cat((init_batch, target_batch), dim=1) + # Flux at 1st windowed position is index 0 in forcing + flux_batch = forcing_batch[:, :, :, 0] + # (N_batch, d_features,) + means.append(torch.mean(batch, dim=(1, 2)).cpu()) + squares.append( + torch.mean(batch**2, dim=(1, 2)).cpu() + ) # (N_batch, d_features,) + flux_means.append(torch.mean(flux_batch).cpu()) # (,) + flux_squares.append(torch.mean(flux_batch**2).cpu()) # (,) + + if distributed and world_size > 1: + means_gathered, squares_gathered = [None] * world_size, [ + None + ] * world_size + flux_means_gathered, flux_squares_gathered = ( + [None] * world_size, + [None] * world_size, + ) + dist.all_gather_object(means_gathered, torch.cat(means, dim=0)) + dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0)) + dist.all_gather_object(flux_means_gathered, flux_means) + dist.all_gather_object(flux_squares_gathered, flux_squares) + + if rank == 0: + means_gathered, squares_gathered = ( + torch.cat(means_gathered, dim=0), + torch.cat(squares_gathered, dim=0), + ) + flux_means_gathered, flux_squares_gathered = ( + torch.tensor(flux_means_gathered), + torch.tensor(flux_squares_gathered), + ) + + original_indices = ds.get_original_indices() + means, squares = ( + [means_gathered[i] for i in original_indices], + [squares_gathered[i] for i in original_indices], + ) + flux_means, flux_squares = ( + [flux_means_gathered[i] for i in original_indices], + [flux_squares_gathered[i] for i in original_indices], + ) + else: + means = [torch.cat(means, dim=0)] # (N_batch, d_features,) + squares = [torch.cat(squares, dim=0)] # (N_batch, d_features,) + flux_means = [torch.tensor(flux_means)] # (N_batch,) + flux_squares = [torch.tensor(flux_squares)] # (N_batch,) + + if rank == 0: + save_stats( + static_dir_path, + means, + squares, + flux_means, + flux_squares, + "parameter", + ) + + if distributed: + dist.barrier() + + if rank == 0: + print("Computing mean and std.-dev. for one-step differences...") + ds_standard = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + split="train", + ar_steps=ar_steps, + standardize=True, + num_past_forcing_steps=0, + num_future_forcing_steps=0, + ) # Re-load with standardization + if distributed: + ds_standard = PaddedWeatherDataset( + ds_standard, + world_size, + batch_size, + ) + sampler_standard = DistributedSampler( + ds_standard, num_replicas=world_size, rank=rank, shuffle=False + ) + else: + sampler_standard = None + loader_standard = torch.utils.data.DataLoader( + ds_standard, + batch_size, + shuffle=False, + num_workers=n_workers, + sampler=sampler_standard, + ) + used_subsample_len = (65 // step_length) * step_length + + diff_means, diff_squares = [], [] + + for init_batch, target_batch, _, _, _ in tqdm( + loader_standard, disable=rank != 0 + ): + if distributed: + init_batch, target_batch = init_batch.to(device), target_batch.to( + device + ) + # (N_batch, N_t', N_grid, d_features) + batch = torch.cat((init_batch, target_batch), dim=1) + # Note: batch contains only 1h-steps + stepped_batch = torch.cat( + [ + batch[:, ss_i:used_subsample_len:step_length] + for ss_i in range(step_length) + ], + dim=0, + ) + # (N_batch', N_t, N_grid, d_features), + # N_batch' = step_length*N_batch + batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] + # (N_batch', N_t-1, N_grid, d_features) + diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu()) + # (N_batch', d_features,) + diff_squares.append(torch.mean(batch_diffs**2, dim=(1, 2)).cpu()) + # (N_batch', d_features,) + + if distributed and world_size > 1: + dist.barrier() + diff_means_gathered, diff_squares_gathered = ( + [None] * world_size, + [None] * world_size, + ) + dist.all_gather_object( + diff_means_gathered, torch.cat(diff_means, dim=0) + ) + dist.all_gather_object( + diff_squares_gathered, torch.cat(diff_squares, dim=0) + ) + + if rank == 0: + diff_means_gathered, diff_squares_gathered = ( + torch.cat(diff_means_gathered, dim=0).view( + -1, *diff_means[0].shape + ), + torch.cat(diff_squares_gathered, dim=0).view( + -1, *diff_squares[0].shape + ), + ) + original_indices = ds_standard.get_original_window_indices( + step_length + ) + diff_means, diff_squares = ( + [diff_means_gathered[i] for i in original_indices], + [diff_squares_gathered[i] for i in original_indices], + ) + + diff_means = [torch.cat(diff_means, dim=0)] # (N_batch', d_features,) + diff_squares = [torch.cat(diff_squares, dim=0)] # (N_batch', d_features,) + + if rank == 0: + save_stats(static_dir_path, diff_means, diff_squares, [], [], "diff") + + if distributed: + dist.destroy_process_group() + + +def cli(): + parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "--datastore_config_path", + type=str, + help="Path to data config file", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size when iterating over the dataset", + ) + parser.add_argument( + "--step_length", + type=int, + default=3, + help="Step length in hours to consider single time step (default: 3)", + ) + parser.add_argument( + "--n_workers", + type=int, + default=4, + help="Number of workers in data loader (default: 4)", + ) + parser.add_argument( + "--distributed", + action="store_true", + help="Run the script in distributed mode (default: False)", + ) + args = parser.parse_args() + distributed = bool(args.distributed) + + main( + datastore_config_path=args.datastore_config_path, + batch_size=args.batch_size, + step_length=args.step_length, + n_workers=args.n_workers, + distributed=distributed, + ) + + +if __name__ == "__main__": + cli() diff --git a/neural_lam/datastore/npyfilesmeps/config.py b/neural_lam/datastore/npyfilesmeps/config.py new file mode 100644 index 00000000..1a9d7295 --- /dev/null +++ b/neural_lam/datastore/npyfilesmeps/config.py @@ -0,0 +1,66 @@ +# Standard library +from dataclasses import dataclass, field +from typing import Any, Dict, List + +# Third-party +import dataclass_wizard + + +@dataclass +class Projection: + """Represents the projection information for a dataset, including the type + of projection and its parameters. Capable of creating a cartopy.crs + projection object. + + Attributes: + class_name: The class name of the projection, this should be a valid + cartopy.crs class. + kwargs: A dictionary of keyword arguments specific to the projection + type. + + """ + + class_name: str + kwargs: Dict[str, Any] + + +@dataclass +class Dataset: + """Contains information about the dataset, including variable names, units, + and descriptions. + + Attributes: + name: The name of the dataset. + var_names: A list of variable names in the dataset. + var_units: A list of units for each variable. + var_longnames: A list of long, descriptive names for each variable. + num_forcing_features: The number of forcing features in the dataset. + + """ + + name: str + var_names: List[str] + var_units: List[str] + var_longnames: List[str] + num_forcing_features: int + num_timesteps: int + step_length: int + num_ensemble_members: int + remove_state_features_with_index: List[int] = field(default_factory=list) + + +@dataclass +class NpyDatastoreConfig(dataclass_wizard.YAMLWizard): + """Configuration for loading and processing a dataset, including dataset + details, grid shape, and projection information. + + Attributes: + dataset: An instance of Dataset containing details about the dataset. + grid_shape_state: A list representing the shape of the grid state. + projection: An instance of Projection containing projection details. + + """ + + dataset: Dataset + grid_shape_state: List[int] + projection: Projection diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py new file mode 100644 index 00000000..22588a06 --- /dev/null +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -0,0 +1,786 @@ +""" +Numpy-files based datastore to support the MEPS example dataset introduced in +neural-lam v0.1.0. +""" + +# Standard library +import functools +import re +import warnings +from functools import cached_property +from pathlib import Path +from typing import List + +# Third-party +import cartopy.crs as ccrs +import dask +import dask.array +import dask.delayed +import numpy as np +import parse +import torch +import xarray as xr +from xarray.core.dataarray import DataArray + +# Local +from ..base import BaseRegularGridDatastore, CartesianGridShape +from .config import NpyDatastoreConfig + +STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy" +TOA_SW_DOWN_FLUX_FILENAME_FORMAT = ( + "nwp_toa_downwelling_shortwave_flux_{analysis_time:%Y%m%d%H}.npy" +) +OPEN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy" + + +def _load_np(fp, add_feature_dim, feature_dim_mask=None): + arr = np.load(fp) + if add_feature_dim: + arr = arr[..., np.newaxis] + if feature_dim_mask is not None: + arr = arr[..., feature_dim_mask] + return arr + + +class NpyFilesDatastoreMEPS(BaseRegularGridDatastore): + __doc__ = f""" + Represents a dataset stored as numpy files on disk. The dataset is assumed + to be stored in a directory structure where each sample is stored in a + separate file. The file-name format is assumed to be + '{STATE_FILENAME_FORMAT}' + + The MEPS dataset is organised into three splits: train, val, and test. Each + split has a set of files which are: + + - `{STATE_FILENAME_FORMAT}`: + The state variables for a forecast started at `analysis_time` with + member id `member_id`. The dimensions of the array are + `[forecast_timestep, y, x, feature]`. + + - `{TOA_SW_DOWN_FLUX_FILENAME_FORMAT}`: + The top-of-atmosphere downwelling shortwave flux at `time`. The + dimensions of the array are `[forecast_timestep, y, x]`. + + - `{OPEN_WATER_FILENAME_FORMAT}`: + The open water fraction at `time`. The dimensions of the array are + `[y, x]`. + + + Folder structure: + + meps_example_reduced + ├── data_config.yaml + ├── samples + │ ├── test + │ │ ├── nwp_2022090100_mbr000.npy + │ │ ├── nwp_2022090100_mbr001.npy + │ │ ├── nwp_2022090112_mbr000.npy + │ │ ├── nwp_2022090112_mbr001.npy + │ │ ├── ... + │ │ ├── nwp_toa_downwelling_shortwave_flux_2022090100.npy + │ │ ├── nwp_toa_downwelling_shortwave_flux_2022090112.npy + │ │ ├── ... + │ │ ├── wtr_2022090100.npy + │ │ ├── wtr_2022090112.npy + │ │ └── ... + │ ├── train + │ │ ├── nwp_2022040100_mbr000.npy + │ │ ├── nwp_2022040100_mbr001.npy + │ │ ├── ... + │ │ ├── nwp_2022040112_mbr000.npy + │ │ ├── nwp_2022040112_mbr001.npy + │ │ ├── ... + │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy + │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy + │ │ ├── ... + │ │ ├── wtr_2022040100.npy + │ │ ├── wtr_2022040112.npy + │ │ └── ... + │ └── val + │ ├── nwp_2022060500_mbr000.npy + │ ├── nwp_2022060500_mbr001.npy + │ ├── ... + │ ├── nwp_2022060512_mbr000.npy + │ ├── nwp_2022060512_mbr001.npy + │ ├── ... + │ ├── nwp_toa_downwelling_shortwave_flux_2022060500.npy + │ ├── nwp_toa_downwelling_shortwave_flux_2022060512.npy + │ ├── ... + │ ├── wtr_2022060500.npy + │ ├── wtr_2022060512.npy + │ └── ... + └── static + ├── border_mask.npy + ├── diff_mean.pt + ├── diff_std.pt + ├── flux_stats.pt + ├── grid_features.pt + ├── nwp_xy.npy + ├── parameter_mean.pt + ├── parameter_std.pt + ├── parameter_weights.npy + └── surface_geopotential.npy + + For the MEPS dataset: + N_t' = 65 + N_t = 65//subsample_step (= 21 for 3h steps) + dim_y = 268 + dim_x = 238 + N_grid = 268x238 = 63784 + d_features = 17 (d_features' = 18) + d_forcing = 5 + + For the MEPS reduced dataset: + N_t' = 65 + N_t = 65//subsample_step (= 21 for 3h steps) + dim_y = 134 + dim_x = 119 + N_grid = 134x119 = 15946 + d_features = 8 + d_forcing = 1 + """ + SHORT_NAME = "npyfilesmeps" + + is_forecast = True + + def __init__( + self, + config_path, + ): + """ + Create a new NpyFilesDatastore using the configuration file at the + given path. The config file should be a YAML file and will be loaded + into an instance of the `NpyDatastoreConfig` dataclass. + + Internally, the datastore uses dask.delayed to load the data from the + numpy files, so that the data isn't actually loaded until it's needed. + + Parameters + ---------- + config_path : str + The path to the configuration file for the datastore. + + """ + self._config_path = Path(config_path) + self._root_path = self._config_path.parent + self._config = NpyDatastoreConfig.from_yaml_file(self._config_path) + + self._num_timesteps = self.config.dataset.num_timesteps + self._step_length = self.config.dataset.step_length + self._remove_state_features_with_index = ( + self.config.dataset.remove_state_features_with_index + ) + + @property + def root_path(self) -> Path: + """ + The root path of the datastore on disk. This is the directory relative + to which graphs and other files can be stored. + + Returns + ------- + Path + The root path of the datastore + + """ + return self._root_path + + @property + def config(self) -> NpyDatastoreConfig: + """The configuration for the datastore. + + Returns + ------- + NpyDatastoreConfig + The configuration for the datastore. + + """ + return self._config + + @property + def num_ensemble_members(self) -> int: + """Return the number of ensemble members in the dataset as defined in + the config file. + + Returns + ------- + int + The number of ensemble members in the dataset. + + """ + return self.config.dataset.num_ensemble_members + + def get_dataarray( + self, category: str, split: str, standardize: bool = False + ) -> DataArray: + """ + Get the data array for the given category and split of data. If the + category is 'state', the data array will be a concatenation of the data + arrays for all ensemble members. The data will be loaded as a dask + array, so that the data isn't actually loaded until it's needed. + + Parameters + ---------- + category : str + The category of the data to load. One of 'state', 'forcing', or + 'static'. + split : str + The dataset split to load the data for. One of 'train', 'val', or + 'test'. + standardize: bool + If the dataarray should be returned standardized + + Returns + ------- + xr.DataArray + The data array for the given category and split, with dimensions + per category: + state: `[elapsed_forecast_duration, analysis_time, grid_index, + feature, ensemble_member]` + forcing: `[elapsed_forecast_duration, analysis_time, grid_index, + feature]` + static: `[grid_index, feature]` + + """ + if category == "state": + das = [] + # for the state category, we need to load all ensemble members + for member in range(self.num_ensemble_members): + da_member = self._get_single_timeseries_dataarray( + features=self.get_vars_names(category="state"), + split=split, + member=member, + ) + das.append(da_member) + da = xr.concat(das, dim="ensemble_member") + + elif category == "forcing": + # the forcing features are in separate files, so we need to load + # them separately + features = ["toa_downwelling_shortwave_flux", "open_water_fraction"] + das = [ + self._get_single_timeseries_dataarray( + features=[feature], split=split + ) + for feature in features + ] + da = xr.concat(das, dim="feature") + + # add datetime forcing as a feature + # to do this we create a forecast time variable which has the + # dimensions of (analysis_time, elapsed_forecast_duration) with + # values that are the actual forecast time of each time step. By + # calling .chunk({"elapsed_forecast_duration": 1}) this time + # variable is turned into a dask array and so execution of the + # calculation is delayed until the feature values are actually + # used. + da_forecast_time = ( + da.analysis_time + da.elapsed_forecast_duration + ).chunk({"elapsed_forecast_duration": 1}) + da_datetime_forcing_features = self._calc_datetime_forcing_features( + da_time=da_forecast_time + ) + da = xr.concat([da, da_datetime_forcing_features], dim="feature") + + elif category == "static": + # the static features are collected in three files: + # - surface_geopotential + # - border_mask + # - x, y + das = [] + for features in [ + ["surface_geopotential"], + ["border_mask"], + ["x", "y"], + ]: + da = self._get_single_timeseries_dataarray( + features=features, split=split + ) + das.append(da) + da = xr.concat(das, dim="feature") + + else: + raise NotImplementedError(category) + + da = da.rename(dict(feature=f"{category}_feature")) + + # stack the [x, y] dimensions into a `grid_index` dimension + da = self.stack_grid_coords(da) + + # check that we have the right features + actual_features = da[f"{category}_feature"].values.tolist() + expected_features = self.get_vars_names(category=category) + if actual_features != expected_features: + raise ValueError( + f"Expected features {expected_features}, got {actual_features}" + ) + + dim_order = self.expected_dim_order(category=category) + da = da.transpose(*dim_order) + + if standardize: + return self._standardize_datarray(da, category=category) + + return da + + def _get_single_timeseries_dataarray( + self, features: List[str], split: str, member: int = None + ) -> DataArray: + """ + Get the data array spanning the complete time series for a given set of + features and split of data. For state features the `member` argument + should be specified to select the ensemble member to load. The data + will be loaded using dask.delayed, so that the data isn't actually + loaded until it's needed. + + Parameters + ---------- + features : List[str] + The list of features to load the data for. For the 'state' + category, this should be the result of + `self.get_vars_names(category="state")`, for the 'forcing' category + this should be the list of forcing features to load, and for the + 'static' category this should be the list of static features to + load. + split : str + The dataset split to load the data for. One of 'train', 'val', or + 'test'. + member : int, optional + The ensemble member to load. Only applicable for the 'state' + category. + + Returns + ------- + xr.DataArray + The data array for the given category and split, with dimensions + `[elapsed_forecast_duration, analysis_time, grid_index, feature]` + for all categories of data + + """ + if ( + set(features).difference(self.get_vars_names(category="static")) + == set() + ): + assert split in ( + "train", + "val", + "test", + None, + ), "Unknown dataset split" + else: + assert split in ( + "train", + "val", + "test", + ), f"Unknown dataset split {split} for features {features}" + + if member is not None and features != self.get_vars_names( + category="state" + ): + raise ValueError( + "Member can only be specified for the 'state' category" + ) + + concat_axis = 0 + + file_params = {} + add_feature_dim = False + features_vary_with_analysis_time = True + feature_dim_mask = None + if features == self.get_vars_names(category="state"): + filename_format = STATE_FILENAME_FORMAT + file_dims = ["elapsed_forecast_duration", "y", "x", "feature"] + # only select one member for now + file_params["member_id"] = member + fp_samples = self.root_path / "samples" / split + if self._remove_state_features_with_index: + n_to_drop = len(self._remove_state_features_with_index) + feature_dim_mask = np.ones( + len(features) + n_to_drop, dtype=bool + ) + feature_dim_mask[self._remove_state_features_with_index] = False + elif features == ["toa_downwelling_shortwave_flux"]: + filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT + file_dims = ["elapsed_forecast_duration", "y", "x", "feature"] + add_feature_dim = True + fp_samples = self.root_path / "samples" / split + elif features == ["open_water_fraction"]: + filename_format = OPEN_WATER_FILENAME_FORMAT + file_dims = ["y", "x", "feature"] + add_feature_dim = True + fp_samples = self.root_path / "samples" / split + elif features == ["surface_geopotential"]: + filename_format = "surface_geopotential.npy" + file_dims = ["y", "x", "feature"] + add_feature_dim = True + features_vary_with_analysis_time = False + # XXX: surface_geopotential is the same for all splits, and so + # saved in static/ + fp_samples = self.root_path / "static" + elif features == ["border_mask"]: + filename_format = "border_mask.npy" + file_dims = ["y", "x", "feature"] + add_feature_dim = True + features_vary_with_analysis_time = False + # XXX: border_mask is the same for all splits, and so saved in + # static/ + fp_samples = self.root_path / "static" + elif features == ["x", "y"]: + filename_format = "nwp_xy.npy" + # NB: for x, y the feature dimension is the first one + file_dims = ["feature", "y", "x"] + features_vary_with_analysis_time = False + # XXX: x, y are the same for all splits, and so saved in static/ + fp_samples = self.root_path / "static" + else: + raise NotImplementedError( + f"Reading of variables set `{features}` not supported" + ) + + if features_vary_with_analysis_time: + dims = ["analysis_time"] + file_dims + else: + dims = file_dims + + coords = {} + arr_shape = [] + + xy = self.get_xy(category="state", stacked=False) + xs = xy[:, :, 0] + ys = xy[:, :, 1] + # Check if x-coordinates are constant along columns + assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant" + # Check if y-coordinates are constant along rows + assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant" + # Extract unique x and y coordinates + x = xs[:, 0] # Unique x-coordinates (changes along the first axis) + y = ys[0, :] # Unique y-coordinates (changes along the second axis) + for d in dims: + if d == "elapsed_forecast_duration": + coord_values = ( + self.step_length + * np.arange(self._num_timesteps) + * np.timedelta64(1, "h") + ) + elif d == "analysis_time": + coord_values = self._get_analysis_times( + split=split, member_id=member + ) + elif d == "y": + coord_values = y + elif d == "x": + coord_values = x + elif d == "feature": + coord_values = features + else: + raise NotImplementedError(f"Dimension {d} not supported") + + coords[d] = coord_values + if d != "analysis_time": + # analysis_time varies across the different files, but not + # within a single file + arr_shape.append(len(coord_values)) + + if features_vary_with_analysis_time: + filepaths = [ + fp_samples + / filename_format.format( + analysis_time=analysis_time, **file_params + ) + for analysis_time in coords["analysis_time"] + ] + else: + filepaths = [fp_samples / filename_format.format(**file_params)] + + # use dask.delayed to load the numpy files, so that loading isn't + # done until the data is actually needed + arrays = [ + dask.array.from_delayed( + dask.delayed(_load_np)( + fp=fp, + add_feature_dim=add_feature_dim, + feature_dim_mask=feature_dim_mask, + ), + shape=arr_shape, + dtype=np.float32, + ) + for fp in filepaths + ] + + # read a single timestep and check the shape + arr0 = arrays[0].compute() + if not list(arr0.shape) == arr_shape: + raise Exception( + f"Expected shape {arr_shape} for a single file, got " + f"{list(arr0.shape)}. Maybe the number of features given " + f"in the datastore config ({features}) is incorrect?" + ) + + if features_vary_with_analysis_time: + arr_all = dask.array.stack(arrays, axis=concat_axis) + else: + arr_all = arrays[0] + + da = xr.DataArray(arr_all, dims=dims, coords=coords) + + return da + + def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: + """Get the analysis times for the given split by parsing the filenames + of all the files found for the given split. + + Parameters + ---------- + split : str + The dataset split to get the analysis times for. + member_id : int + The ensemble member to get the analysis times for. + + Returns + ------- + List[dt.datetime] + The analysis times for the given split. + + """ + if member_id is None: + # Only interior state data files have member_id, to avoid duplicates + # we only look at the first member for all other categories + member_id = 0 + pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) + pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern) + + sample_dir = self.root_path / "samples" / split + sample_files = sample_dir.glob(pattern) + times = [] + for fp in sample_files: + name_parts = parse.parse(STATE_FILENAME_FORMAT, fp.name) + times.append(name_parts["analysis_time"]) + + if len(times) == 0: + raise ValueError( + f"No files found in {sample_dir} with pattern {pattern}" + ) + + return times + + def _calc_datetime_forcing_features(self, da_time: xr.DataArray): + da_hour_angle = da_time.dt.hour / 12 * np.pi + da_year_angle = da_time.dt.dayofyear / 365 * 2 * np.pi + + da_datetime_forcing = xr.concat( + ( + np.sin(da_hour_angle), + np.cos(da_hour_angle), + np.sin(da_year_angle), + np.cos(da_year_angle), + ), + dim="feature", + ) + da_datetime_forcing = (da_datetime_forcing + 1) / 2 # Rescale to [0,1] + da_datetime_forcing["feature"] = [ + "sin_hour", + "cos_hour", + "sin_year", + "cos_year", + ] + + return da_datetime_forcing + + def get_vars_units(self, category: str) -> List[str]: + if category == "state": + return self.config.dataset.var_units + elif category == "forcing": + return [ + "W/m^2", + "1", + "1", + "1", + "1", + "1", + ] + elif category == "static": + return ["m^2/s^2", "1", "m", "m"] + else: + raise NotImplementedError(f"Category {category} not supported") + + def get_vars_names(self, category: str) -> List[str]: + if category == "state": + return self.config.dataset.var_names + elif category == "forcing": + # XXX: this really shouldn't be hard-coded here, this should be in + # the config + return [ + "toa_downwelling_shortwave_flux", + "open_water_fraction", + "sin_hour", + "cos_hour", + "sin_year", + "cos_year", + ] + elif category == "static": + return ["surface_geopotential", "border_mask", "x", "y"] + else: + raise NotImplementedError(f"Category {category} not supported") + + def get_vars_long_names(self, category: str) -> List[str]: + if category == "state": + return self.config.dataset.var_longnames + else: + # TODO: should we add these? + return self.get_vars_names(category=category) + + def get_num_data_vars(self, category: str) -> int: + return len(self.get_vars_names(category=category)) + + def get_xy(self, category: str, stacked: bool = True) -> np.ndarray: + """Return the x, y coordinates of the dataset. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + stacked : bool + Whether to stack the x, y coordinates. + + Returns + ------- + np.ndarray + The x, y coordinates of the dataset (with x first then y second), + returned differently based on the value of `stacked`: + - `stacked==True`: shape `(n_grid_points, 2)` where + n_grid_points=N_x*N_y. + - `stacked==False`: shape `(N_x, N_y, 2)` + + """ + + # the array on disk has shape [2, N_y, N_x], where dimension 0 + # contains the [x,y] coordinate pairs for each grid point + arr = np.load(self.root_path / "static" / "nwp_xy.npy") + arr_shape = arr.shape + + assert arr_shape[0] == 2, "Expected 2D array" + grid_shape = self.grid_shape_state + assert arr_shape[1:] == (grid_shape.y, grid_shape.x), "Unexpected shape" + + arr = arr.transpose(2, 1, 0) + + if stacked: + return arr.reshape(-1, 2) + else: + return arr + + @property + def step_length(self) -> int: + """The length of each time step in hours. + + Returns + ------- + int + The length of each time step in hours. + + """ + return self._step_length + + @cached_property + def grid_shape_state(self) -> CartesianGridShape: + """The shape of the cartesian grid for the state variables. + + Returns + ------- + CartesianGridShape + The shape of the cartesian grid for the state variables. + + """ + ny, nx = self.config.grid_shape_state + return CartesianGridShape(x=nx, y=ny) + + def get_standardization_dataarray(self, category: str) -> xr.Dataset: + """Return the standardization dataarray for the given category. This + should contain a `{category}_mean` and `{category}_std` variable for + each variable in the category. For `category=="state"`, the dataarray + should also contain a `state_diff_mean` and `state_diff_std` variable + for the one- step differences of the state variables. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + xr.Dataset + The standardization dataarray for the given category, with + variables for the mean and standard deviation of the variables (and + differences for state variables). + + """ + + def load_pickled_tensor(fn): + return torch.load( + self.root_path / "static" / fn, weights_only=True + ).numpy() + + mean_diff_values = None + std_diff_values = None + if category == "state": + mean_values = load_pickled_tensor("parameter_mean.pt") + std_values = load_pickled_tensor("parameter_std.pt") + try: + mean_diff_values = load_pickled_tensor("diff_mean.pt") + std_diff_values = load_pickled_tensor("diff_std.pt") + except FileNotFoundError: + warnings.warn(f"Could not load diff mean/std for {category}") + # XXX: this is a hack, but when running + # compute_standardization_stats the diff mean/std files are + # created, but require the std and mean files + mean_diff_values = np.empty_like(mean_values) + std_diff_values = np.empty_like(std_values) + + elif category == "forcing": + flux_stats = load_pickled_tensor("flux_stats.pt") # (2,) + flux_mean, flux_std = flux_stats + # manually add hour sin/cos and day-of-year sin/cos stats for now + # the mean/std for open_water_fraction is hardcoded for now + mean_values = np.array([flux_mean, 0.0, 0.0, 0.0, 0.0, 0.0]) + std_values = np.array([flux_std, 1.0, 1.0, 1.0, 1.0, 1.0]) + + elif category == "static": + da_static = self.get_dataarray(category="static", split="train") + da_static_mean = da_static.mean(dim=["grid_index"]).compute() + da_static_std = da_static.std(dim=["grid_index"]).compute() + mean_values = da_static_mean.values + std_values = da_static_std.values + else: + raise NotImplementedError(f"Category {category} not supported") + + feature_dim_name = f"{category}_feature" + variables = { + f"{category}_mean": (feature_dim_name, mean_values), + f"{category}_std": (feature_dim_name, std_values), + } + + if mean_diff_values is not None and std_diff_values is not None: + variables["state_diff_mean"] = (feature_dim_name, mean_diff_values) + variables["state_diff_std"] = (feature_dim_name, std_diff_values) + + ds_norm = xr.Dataset( + variables, + coords={feature_dim_name: self.get_vars_names(category=category)}, + ) + + return ds_norm + + @functools.cached_property + def coords_projection(self) -> ccrs.Projection: + """The projection of the spatial coordinates. + + Returns + ------- + ccrs.Projection + The projection of the spatial coordinates. + + """ + proj_class_name = self.config.projection.class_name + ProjectionClass = getattr(ccrs, proj_class_name) + proj_params = self.config.projection.kwargs + return ProjectionClass(**proj_params) diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py new file mode 100644 index 00000000..2d477271 --- /dev/null +++ b/neural_lam/datastore/plot_example.py @@ -0,0 +1,189 @@ +# Third-party +import matplotlib.pyplot as plt + +# Local +from . import DATASTORES, init_datastore + + +def plot_example_from_datastore( + category, + datastore, + col_dim, + split="train", + standardize=True, + selection={}, + index_selection={}, +): + """ + Create a plot of the data from the datastore. + + Parameters + ---------- + category : str + Category of data to plot, one of "state", "forcing", or "static". + datastore : Datastore + Datastore to retrieve data from. + col_dim : str + Dimension to use for plot facetting into columns. This can be a + template string that can be formatted with the category name. + split : str, optional + Split of data to plot, by default "train". + standardize : bool, optional + Whether to standardize the data before plotting, by default True. + selection : dict, optional + Selections to apply to the dataarray, for example + `time="1990-09-03T0:00" would select this single timestep, by default + {}. + index_selection: dict, optional + Index-based selection to apply to the dataarray, for example + `time=0` would select the first item along the `time` dimension, by + default {}. + + Returns + ------- + Figure + Matplotlib figure object. + """ + da = datastore.get_dataarray(category=category, split=split) + if standardize: + da_stats = datastore.get_standardization_dataarray(category=category) + da = (da - da_stats[f"{category}_mean"]) / da_stats[f"{category}_std"] + da = datastore.unstack_grid_coords(da) + + if len(selection) > 0: + da = da.sel(**selection) + if len(index_selection) > 0: + da = da.isel(**index_selection) + + col = col_dim.format(category=category) + + # check that the column dimension exists and that the resulting shape is 2D + if col not in da.dims: + raise ValueError(f"Column dimension {col} not found in dataarray.") + da_col_item = da.isel({col: 0}).squeeze() + if not len(da_col_item.shape) == 2: + raise ValueError( + f"Column dimension {col} and selection {selection} does not " + "result in a 2D dataarray. Please adjust the column dimension " + "and/or selection. Instead the resulting dataarray is:\n" + f"{da_col_item}" + ) + + crs = datastore.coords_projection + col_wrap = min(4, int(da[col].count())) + g = da.plot( + x="x", + y="y", + col=col, + col_wrap=col_wrap, + subplot_kws={"projection": crs}, + transform=crs, + size=4, + ) + for ax in g.axes.flat: + ax.coastlines() + ax.gridlines(draw_labels=["left", "bottom"]) + ax.set_extent(datastore.get_xy_extent(category=category), crs=crs) + + return g.fig + + +if __name__ == "__main__": + # Standard library + import argparse + + def _parse_dict(arg_str): + key, value = arg_str.split("=") + for op in [int, float]: + try: + value = op(value) + break + except ValueError: + pass + return key, value + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--datastore_kind", + type=str, + choices=DATASTORES.keys(), + default="mdp", + help="Kind of datastore to use", + ) + parser.add_argument( + "--datastore_config_path", + type=str, + default=None, + help="Path for the datastore config", + ) + parser.add_argument( + "--category", + default="state", + help="Category of data to plot", + choices=["state", "forcing", "static"], + ) + parser.add_argument( + "--split", default="train", help="Split of data to plot" + ) + parser.add_argument( + "--col-dim", + default="{category}_feature", + help="Dimension to use for plot facetting into columns", + ) + parser.add_argument( + "--disable-standardize", + dest="standardize", + action="store_false", + help="Disable standardization of data", + ) + # add the ability to create dictionary of kwargs + parser.add_argument( + "--selection", + nargs="+", + default=[], + type=_parse_dict, + help="Selections to apply to the dataarray, for example " + "`time='1990-09-03T0:00' would select this single timestep", + ) + parser.add_argument( + "--index-selection", + nargs="+", + default=[], + type=_parse_dict, + help="Index-based selection to apply to the dataarray, for example " + "`time=0` would select the first item along the `time` dimension", + ) + args = parser.parse_args() + + assert ( + args.datastore_config_path is not None + ), "Specify your datastore config with --datastore_config_path" + + selection = dict(args.selection) + index_selection = dict(args.index_selection) + + # check that column dimension is not in the selection + if args.col_dim.format(category=args.category) in selection: + raise ValueError( + f"Column dimension {args.col_dim.format(category=args.category)} " + f"cannot be in the selection ({selection}). Please adjust the " + "column dimension and/or selection." + ) + + datastore = init_datastore( + datastore_kind=args.datastore_kind, + config_path=args.datastore_config_path, + ) + + plot_example_from_datastore( + args.category, + datastore, + split=args.split, + col_dim=args.col_dim, + standardize=args.standardize, + selection=selection, + index_selection=index_selection, + ) + plt.show() diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index 663f27e4..14a0d1c7 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -3,8 +3,8 @@ import torch_geometric as pyg from torch import nn -# First-party -from neural_lam import utils +# Local +from . import utils class InteractionNet(pyg.nn.MessagePassing): @@ -25,12 +25,14 @@ def __init__( hidden_dim=None, edge_chunk_sizes=None, aggr_chunk_sizes=None, + num_rec=None, aggr="sum", ): """ Create a new InteractionNet - edge_index: (2,M), Edges in pyg format + edge_index: (2,M), Edges in pyg format, with both sender and receiver + node indices starting at 0 input_dim: Dimensionality of input representations, for both nodes and edges update_edges: If new edge representations should be computed @@ -43,6 +45,8 @@ def __init__( aggr_chunk_sizes: List of chunks sizes to split aggregated node representation into and use separate MLPs for (None = no chunking, same MLP) + num_rec: Number of receiver nodes. If None, derive from edge_index under + assumption that all receiver nodes have at least one incoming edge. aggr: Message aggregation method (sum/mean) """ assert aggr in ("sum", "mean"), f"Unknown aggregation method: {aggr}" @@ -52,12 +56,16 @@ def __init__( # Default to input dim if not explicitly given hidden_dim = input_dim - # Make both sender and receiver indices of edge_index start at 0 - edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] # Store number of receiver nodes according to edge_index - self.num_rec = edge_index[1].max() + 1 - edge_index[0] = ( - edge_index[0] + self.num_rec + if num_rec is None: + # Derive from edge_index + self.num_rec = edge_index[1].max() + 1 + else: + self.num_rec = num_rec + + # any edge_index used here must start sender and rec. nodes at index 0 + edge_index = torch.stack( + (edge_index[0] + self.num_rec, edge_index[1]), dim=0 ) # Make sender indices after rec self.register_buffer("edge_index", edge_index, persistent=False) diff --git a/neural_lam/loss_weighting.py b/neural_lam/loss_weighting.py new file mode 100644 index 00000000..c842b202 --- /dev/null +++ b/neural_lam/loss_weighting.py @@ -0,0 +1,106 @@ +# Local +from .config import ( + ManualStateFeatureWeighting, + NeuralLAMConfig, + UniformFeatureWeighting, +) +from .datastore.base import BaseDatastore + + +def get_manual_state_feature_weights( + weighting_config: ManualStateFeatureWeighting, datastore: BaseDatastore +) -> list[float]: + """ + Return the state feature weights as a list of floats in the order of the + state features in the datastore. + + Parameters + ---------- + weighting_config : ManualStateFeatureWeighting + Configuration object containing the manual state feature weights. + datastore : BaseDatastore + Datastore object containing the state features. + + Returns + ------- + list[float] + List of floats containing the state feature weights. + """ + state_feature_names = datastore.get_vars_names(category="state") + feature_weight_names = weighting_config.weights.keys() + + # Check that the state_feature_weights dictionary has a weight for each + # state feature in the datastore. + if set(feature_weight_names) != set(state_feature_names): + additional_features = set(feature_weight_names) - set( + state_feature_names + ) + missing_features = set(state_feature_names) - set(feature_weight_names) + raise ValueError( + f"State feature weights must be provided for each state feature" + f"in the datastore ({state_feature_names}). {missing_features}" + " are missing and weights are defined for the features " + f"{additional_features} which are not in the datastore." + ) + + state_feature_weights = [ + weighting_config.weights[feature] for feature in state_feature_names + ] + return state_feature_weights + + +def get_uniform_state_feature_weights(datastore: BaseDatastore) -> list[float]: + """ + Return the state feature weights as a list of floats in the order of the + state features in the datastore. + + The weights are uniform, i.e. 1.0/n_features for each feature. + + Parameters + ---------- + datastore : BaseDatastore + Datastore object containing the state features. + + Returns + ------- + list[float] + List of floats containing the state feature weights. + """ + state_feature_names = datastore.get_vars_names(category="state") + n_features = len(state_feature_names) + return [1.0 / n_features] * n_features + + +def get_state_feature_weighting( + config: NeuralLAMConfig, datastore: BaseDatastore +) -> list[float]: + """ + Return the state feature weights as a list of floats in the order of the + state features in the datastore. The weights are determined based on the + configuration in the NeuralLAMConfig object. + + Parameters + ---------- + config : NeuralLAMConfig + Configuration object for neural-lam. + datastore : BaseDatastore + Datastore object containing the state features. + + Returns + ------- + list[float] + List of floats containing the state feature weights. + """ + weighting_config = config.training.state_feature_weighting + + if isinstance(weighting_config, ManualStateFeatureWeighting): + weights = get_manual_state_feature_weights(weighting_config, datastore) + elif isinstance(weighting_config, UniformFeatureWeighting): + weights = get_uniform_state_feature_weights(datastore) + else: + raise NotImplementedError( + "Unsupported state feature weighting configuration: " + f"{weighting_config}" + ) + + return weights diff --git a/neural_lam/models/__init__.py b/neural_lam/models/__init__.py new file mode 100644 index 00000000..f65387ab --- /dev/null +++ b/neural_lam/models/__init__.py @@ -0,0 +1,6 @@ +# Local +from .base_graph_model import BaseGraphModel +from .base_hi_graph_model import BaseHiGraphModel +from .graph_lam import GraphLAM +from .hi_lam import HiLAM +from .hi_lam_parallel import HiLAMParallel diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index f49eb094..b0a6fbba 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,5 +1,6 @@ # Standard library import os +from typing import List, Union # Third-party import matplotlib.pyplot as plt @@ -7,9 +8,14 @@ import pytorch_lightning as pl import torch import wandb +import xarray as xr -# First-party -from neural_lam import metrics, utils, vis +# Local +from .. import metrics, vis +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset class ARModel(pl.LightningModule): @@ -21,52 +27,148 @@ class ARModel(pl.LightningModule): # pylint: disable=arguments-differ # Disable to override args/kwargs from superclass - def __init__(self, args): + def __init__( + self, + args, + config: NeuralLAMConfig, + datastore: BaseDatastore, + datastore_boundary: Union[BaseDatastore, None], + ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=["datastore"]) self.args = args - self.config_loader = utils.ConfigLoader(args.data_config) + self._datastore = datastore + num_state_vars = datastore.get_num_data_vars(category="state") + num_forcing_vars = datastore.get_num_data_vars(category="forcing") - # Load static features for grid/data - static = self.config_loader.process_dataset("static") + num_past_forcing_steps = args.num_past_forcing_steps + num_future_forcing_steps = args.num_future_forcing_steps + + # Load static features for interior + da_static_features = datastore.get_dataarray( + category="static", split=None, standardize=True + ) self.register_buffer( - "grid_static_features", - torch.tensor(static.values), + "interior_static_features", + torch.tensor(da_static_features.values, dtype=torch.float32), persistent=False, ) + # Load stats for rescaling and weights + da_state_stats = datastore.get_standardization_dataarray( + category="state" + ) + state_stats = { + "state_mean": torch.tensor( + da_state_stats.state_mean.values, dtype=torch.float32 + ), + "state_std": torch.tensor( + da_state_stats.state_std.values, dtype=torch.float32 + ), + "diff_mean": torch.tensor( + da_state_stats.state_diff_mean.values, dtype=torch.float32 + ), + "diff_std": torch.tensor( + da_state_stats.state_diff_std.values, dtype=torch.float32 + ), + } + + for key, val in state_stats.items(): + self.register_buffer(key, val, persistent=False) + + state_feature_weights = get_state_feature_weighting( + config=config, datastore=datastore + ) + self.feature_weights = torch.tensor( + state_feature_weights, dtype=torch.float32 + ) + # Double grid output dim. to also output std.-dev. self.output_std = bool(args.output_std) if self.output_std: # Pred. dim. in grid cell - self.grid_output_dim = 2 * self.config_loader.num_data_vars("state") + self.grid_output_dim = 2 * num_state_vars else: # Pred. dim. in grid cell - self.grid_output_dim = self.config_loader.num_data_vars("state") + self.grid_output_dim = num_state_vars + # Store constant per-variable std.-dev. weighting + # NOTE that this is the inverse of the multiplicative weighting + # in wMSE/wMAE + self.register_buffer( + "per_var_std", + self.diff_std / torch.sqrt(self.feature_weights), + persistent=False, + ) - # grid_dim from data + static + # interior from data + static ( - self.num_grid_nodes, - grid_static_dim, - ) = self.grid_static_features.shape - self.grid_dim = ( - 2 * self.config_loader.num_data_vars("state") - + grid_static_dim - + self.config_loader.num_data_vars("forcing") - * self.config_loader.forcing.window + self.num_interior_nodes, + interior_static_dim, + ) = self.interior_static_features.shape + self.num_total_grid_nodes = self.num_interior_nodes + self.interior_dim = ( + 2 * self.grid_output_dim + + interior_static_dim + + num_forcing_vars + * (num_past_forcing_steps + num_future_forcing_steps + 1) ) + # If datastore_boundary is given, the model is forced from the boundary + self.boundary_forced = datastore_boundary is not None + + if self.boundary_forced: + # Load static features for boundary + da_boundary_static_features = datastore_boundary.get_dataarray( + category="static", split=None, standardize=True + ) + self.register_buffer( + "boundary_static_features", + torch.tensor( + da_boundary_static_features.values, dtype=torch.float32 + ), + persistent=False, + ) + + # Compute dimensionalities (e.g. to instantiate MLPs) + ( + self.num_boundary_nodes, + boundary_static_dim, + ) = self.boundary_static_features.shape + + # Compute boundary input dim separately + num_boundary_forcing_vars = datastore_boundary.get_num_data_vars( + category="forcing" + ) + + # Dimensionality of encoded time deltas + self.time_delta_enc_dim = ( + args.hidden_dim + if args.time_delta_enc_dim is None + else args.time_delta_enc_dim + ) + assert self.time_delta_enc_dim % 2 == 0, ( + "Number of dimensions to use for time delta encoding must be " + "even (sin and cos)" + ) + + num_past_boundary_steps = args.num_past_boundary_steps + num_future_boundary_steps = args.num_future_boundary_steps + self.boundary_dim = ( + boundary_static_dim + # Time delta counts as one additional forcing_feature + + (num_boundary_forcing_vars + self.time_delta_enc_dim) + * (num_past_boundary_steps + num_future_boundary_steps + 1) + ) + # How many of the last boundary forcing dims contain time-deltas + self.boundary_time_delta_dims = ( + num_past_boundary_steps + num_future_boundary_steps + 1 + ) + + self.num_total_grid_nodes += self.num_boundary_nodes + # Instantiate loss function self.loss = metrics.get_metric(args.loss) - border_mask = torch.zeros(self.num_grid_nodes, 1) - self.register_buffer("border_mask", border_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.border_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - - self.step_length = args.step_length # Number of hours per pred. step self.val_metrics = { "mse": [], } @@ -77,8 +179,8 @@ def __init__(self, args): if self.output_std: self.test_metrics["output_std"] = [] # Treat as metric - # For making restoring of optimizer state optional (slight hack) - self.opt_state = None + # For making restoring of optimizer state optional + self.restore_opt = args.restore_opt # For example plotting self.n_example_pred = args.n_example_pred @@ -87,35 +189,70 @@ def __init__(self, args): # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] - # Load normalization statistics - self.normalization_stats = self.config_loader.load_normalization_stats() - if self.normalization_stats is not None: - for ( - var_name, - var_data, - ) in self.normalization_stats.data_vars.items(): - self.register_buffer( - f"{var_name}", - torch.tensor(var_data.values), - persistent=False, + # Set if grad checkpointing function should be used during rollout + if args.grad_checkpointing: + # Perform gradient checkpointing at each unrolling step + self.unroll_ckpt_func = ( + lambda f, *args: torch.utils.checkpoint.checkpoint( + f, *args, use_reentrant=False ) + ) + else: + self.unroll_ckpt_func = lambda f, *args: f(*args) + + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: Union[int, List[int]], + split: str, + category: str, + ) -> xr.DataArray: + """ + Create an `xr.DataArray` from a tensor, with the correct dimensions and + coordinates to match the datastore used by the model. This function in + in effect is the inverse of what is returned by + `WeatherDataset.__getitem__`. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to convert to a `xr.DataArray` with dimensions [time, + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. + time : Union[int,List[int]] + The time index or indices for the data, given as integers or a list + of integers representing epoch time in nanoseconds. The ints will be + copied to the CPU memory if they are not already there. + split : str + The split of the data, either 'train', 'val', or 'test' + category : str + The category of the data, either 'state' or 'forcing' + """ + # TODO: creating an instance of WeatherDataset here on every call is + # not how this should be done but whether WeatherDataset should be + # provided to ARModel or where to put plotting still needs discussion + weather_dataset = WeatherDataset( + datastore=self._datastore, + datastore_boundary=None, + split=split, + ) + + # Move to CPU if on GPU + time = time.detach().cpu() + time = np.array(time, dtype="datetime64[ns]") + + tensor = tensor.detach().cpu() + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor, time=time, category=category + ) + return da def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) - if self.opt_state: - opt.load_state_dict(self.opt_state) - return opt - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. - """ - return self.interior_mask[:, 0].to(torch.bool) - @staticmethod def expand_to_batch(x, batch_size): """ @@ -123,59 +260,65 @@ def expand_to_batch(x, batch_size): """ return x.unsqueeze(0).expand(batch_size, -1, -1) - def predict_step(self, prev_state, prev_prev_state, forcing): + def predict_step( + self, prev_state, prev_prev_state, forcing, boundary_forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, num_grid_nodes, feature_dim), X_t - prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} - forcing: (B, num_grid_nodes, forcing_dim) + prev_state: (B, num_interior_nodes, feature_dim), X_t + prev_prev_state: (B, num_interior_nodes, feature_dim), X_{t-1} + forcing: (B, num_interior_nodes, forcing_dim) + boundary_forcing: (B, num_boundary_nodes, boundary_forcing_dim) """ raise NotImplementedError("No prediction step implemented") - def unroll_prediction(self, init_states, forcing_features, true_states): + def unroll_prediction(self, init_states, forcing, boundary_forcing): """ Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) - forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) - true_states: (B, pred_steps, num_grid_nodes, d_f) + init_states: (B, 2, num_interior_nodes, d_f) + forcing: (B, pred_steps, num_interior_nodes, d_static_f) + boundary_forcing: (B, pred_steps, num_boundary_nodes, d_boundary_f) """ prev_prev_state = init_states[:, 0] prev_state = init_states[:, 1] prediction_list = [] pred_std_list = [] - pred_steps = forcing_features.shape[1] + pred_steps = forcing.shape[1] for i in range(pred_steps): - forcing = forcing_features[:, i] - border_state = true_states[:, i] - - pred_state, pred_std = self.predict_step( - prev_state, prev_prev_state, forcing + forcing_step = forcing[:, i] + + if self.boundary_forced: + boundary_forcing_step = boundary_forcing[:, i] + else: + boundary_forcing_step = None + + pred_state, pred_std = self.unroll_ckpt_func( + self.predict_step, + prev_state, + prev_prev_state, + forcing_step, + boundary_forcing_step, ) - # state: (B, num_grid_nodes, d_f) - # pred_std: (B, num_grid_nodes, d_f) or None + # state: (B, num_interior_nodes, d_f) + # pred_std: (B, num_interior_nodes, d_f) or None - # Overwrite border with true state - new_state = ( - self.border_mask * border_state - + self.interior_mask * pred_state - ) + prediction_list.append(pred_state) - prediction_list.append(new_state) if self.output_std: pred_std_list.append(pred_std) # Update conditioning states prev_prev_state = prev_state - prev_state = new_state + prev_state = pred_state prediction = torch.stack( prediction_list, dim=1 - ) # (B, pred_steps, num_grid_nodes, d_f) + ) # (B, pred_steps, num_interior_nodes, d_f) if self.output_std: pred_std = torch.stack( pred_std_list, dim=1 - ) # (B, pred_steps, num_grid_nodes, d_f) + ) # (B, pred_steps, num_interior_nodes, d_f) else: pred_std = self.diff_std # (d_f,) @@ -185,24 +328,28 @@ def common_step(self, batch): """ Predict on single batch batch consists of: - init_states: (B, 2, num_grid_nodes, d_features) - target_states: (B, pred_steps, num_grid_nodes, d_features) - forcing_features: (B, pred_steps, num_grid_nodes, d_forcing), + init_states: (B, 2, num_interior_nodes, d_features) + target_states: (B, pred_steps, num_interior_nodes, d_features) + forcing: (B, pred_steps, num_interior_nodes, d_forcing), + boundary_forcing: + (B, pred_steps, num_boundary_nodes, d_boundary_forcing), where index 0 corresponds to index 1 of init_states """ ( init_states, target_states, - forcing_features, + forcing, + boundary_forcing, + batch_times, ) = batch prediction, pred_std = self.unroll_prediction( - init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) - # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + init_states, forcing, boundary_forcing + ) # (B, pred_steps, num_interior_nodes, d_f) + # prediction: (B, pred_steps, num_interior_nodes, d_f) pred_std: (B, + # pred_steps, num_interior_nodes, d_f) or (d_f,) - return prediction, target_states, pred_std + return prediction, target_states, pred_std, batch_times def on_after_batch_transfer(self, batch, dataloader_idx): """Normalize Batch data after transferring to the device.""" @@ -228,25 +375,32 @@ def training_step(self, batch): """ Train on single batch """ - prediction, target, pred_std = self.common_step(batch) + prediction, target, pred_std, _ = self.common_step(batch) - # Compute loss + # Compute loss - mean over unrolled times and batch batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ) - ) # mean over unrolled times and batch + ) log_dict = {"train_loss": batch_loss} self.log_dict( - log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True + log_dict, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], ) return batch_loss def all_gather_cat(self, tensor_to_gather): """ - Gather tensors across all ranks, and concatenate across dim. 0 - (instead of stacking in new dim. 0) + Gather tensors across all ranks, and concatenate across dim. 0 (instead + of stacking in new dim. 0) tensor_to_gather: (d1, d2, ...), distributed over K ranks @@ -260,11 +414,13 @@ def validation_step(self, batch, batch_idx): """ Run validation on single batch """ - prediction, target, pred_std = self.common_step(batch) + prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ), dim=0, ) # (time_steps-1) @@ -273,11 +429,16 @@ def validation_step(self, batch, batch_idx): # Log loss per time step forward and mean val_log_dict = { f"val_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_log + for step in self.args.val_steps_to_log + if step <= len(time_step_loss) } val_log_dict["val_mean_loss"] = mean_loss self.log_dict( - val_log_dict, on_step=False, on_epoch=True, sync_dist=True + val_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], ) # Store MSEs @@ -285,7 +446,6 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.val_metrics["mse"].append(entry_mses) @@ -306,13 +466,16 @@ def test_step(self, batch, batch_idx): """ Run test on single batch """ - prediction, target, pred_std = self.common_step(batch) - # prediction: (B, pred_steps, num_grid_nodes, d_f) - # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + # TODO Here batch_times can be used for plotting routines + prediction, target, pred_std, batch_times = self.common_step(batch) + # prediction: (B, pred_steps, num_interior_nodes, d_f) pred_std: (B, + # pred_steps, num_interior_nodes, d_f) or (d_f,) time_step_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ), dim=0, ) # (time_steps-1,) @@ -321,43 +484,45 @@ def test_step(self, batch, batch_idx): # Log loss per time step forward and mean test_log_dict = { f"test_loss_unroll{step}": time_step_loss[step - 1] - for step in self.args.val_steps_log + for step in self.args.val_steps_to_log } test_log_dict["test_mean_loss"] = mean_loss self.log_dict( - test_log_dict, on_step=False, on_epoch=True, sync_dist=True + test_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], ) - # Compute all evaluation metrics for error maps - # Note: explicitly list metrics here, as test_metrics can contain - # additional ones, computed differently, but that should be aggregated - # on_test_epoch_end + # Compute all evaluation metrics for error maps Note: explicitly list + # metrics here, as test_metrics can contain additional ones, computed + # differently, but that should be aggregated on_test_epoch_end for metric_name in ("mse", "mae"): metric_func = metrics.get_metric(metric_name) batch_metric_vals = metric_func( prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) + mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times spatial_loss = self.loss( prediction, target, pred_std, average_grid=False - ) # (B, pred_steps, num_grid_nodes) - log_spatial_losses = spatial_loss[:, self.args.val_steps_log - 1] + ) # (B, pred_steps, num_interior_nodes) + log_spatial_losses = spatial_loss[ + :, [step - 1 for step in self.args.val_steps_to_log] + ] self.spatial_loss_maps.append(log_spatial_losses) - # (B, N_log, num_grid_nodes) + # (B, N_log, num_interior_nodes) # Plot example predictions (on rank 0 only) if ( @@ -366,38 +531,58 @@ def test_step(self, batch, batch_idx): ): # Need to plot more example predictions n_additional_examples = min( - prediction.shape[0], self.n_example_pred - self.plotted_examples + prediction.shape[0], + self.n_example_pred - self.plotted_examples, ) self.plot_examples( - batch, n_additional_examples, prediction=prediction + batch, + n_additional_examples, + prediction=prediction, + split="test", ) - def plot_examples(self, batch, n_examples, prediction=None): + def plot_examples(self, batch, n_examples, split, prediction=None): """ Plot the first n_examples forecasts from batch batch: batch with data to plot corresponding forecasts for n_examples: number of forecasts to plot - prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. - Generate if None. + prediction: (B, pred_steps, num_interior_nodes, d_f), + existing prediction. Generate if None. """ if prediction is None: - prediction, target = self.common_step(batch) + prediction, target, _, _ = self.common_step(batch) target = batch[1] + time = batch[-1] # Rescale to original data scale - prediction_rescaled = prediction * self.std + self.mean - target_rescaled = target * self.std + self.mean + prediction_rescaled = prediction * self.state_std + self.state_mean + target_rescaled = target * self.state_std + self.state_mean # Iterate over the examples - for pred_slice, target_slice in zip( - prediction_rescaled[:n_examples], target_rescaled[:n_examples] + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], ): - # Each slice is (pred_steps, num_grid_nodes, d_f) + # Each slice is (pred_steps, num_interior_nodes, d_f) self.plotted_examples += 1 # Increment already here + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + var_vmin = ( torch.minimum( pred_slice.flatten(0, 1).min(dim=0)[0], @@ -417,35 +602,37 @@ def plot_examples(self, batch, n_examples, prediction=None): var_vranges = list(zip(var_vmin, var_vmax)) # Iterate over prediction horizon time steps - for t_i, (pred_t, target_t) in enumerate( - zip(pred_slice, target_slice), start=1 - ): + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): # Create one figure per variable at this time step var_figs = [ vis.plot_prediction( - pred_t[:, var_i], - target_t[:, var_i], - self.interior_mask[:, 0], - self.config_loader, + datastore=self._datastore, title=f"{var_name} ({var_unit}), " - f"t={t_i} ({self.step_length * t_i} h)", + f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( - self.config_loader.param_names(), - self.config_loader.param_units(), + self._datastore.get_vars_names("state"), + self._datastore.get_vars_units("state"), var_vranges, ) ) ] example_i = self.plotted_examples + wandb.log( { f"{var_name}_example_{example_i}": wandb.Image(fig) for var_name, fig in zip( - self.config_loader.param_names(), var_figs + self._datastore.get_vars_names("state"), var_figs ) } ) @@ -469,19 +656,19 @@ def plot_examples(self, batch, n_examples, prediction=None): def create_metric_log_dict(self, metric_tensor, prefix, metric_name): """ - Put together a dict with everything to log for one metric. - Also saves plots as pdf and csv if using test prefix. + Put together a dict with everything to log for one metric. Also saves + plots as pdf and csv if using test prefix. metric_tensor: (pred_steps, d_f), metric values per time and variable - prefix: string, prefix to use for logging - metric_name: string, name of the metric + prefix: string, prefix to use for logging metric_name: string, name of + the metric - Return: - log_dict: dict with everything to log for given metric + Return: log_dict: dict with everything to log for given metric """ log_dict = {} metric_fig = vis.plot_error_map( - metric_tensor, self.config_loader, step_length=self.step_length + errors=metric_tensor, + datastore=self._datastore, ) full_log_name = f"{prefix}_{metric_name}" log_dict[full_log_name] = wandb.Image(metric_fig) @@ -499,17 +686,13 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): ) # Check if metrics are watched, log exact values for specific vars + var_names = self._datastore.get_vars_names(category="state") if full_log_name in self.args.metrics_watch: for var_i, timesteps in self.args.var_leads_metrics_watch.items(): - var = self.config_loader.param_names()[var_i] - log_dict.update( - { - f"{full_log_name}_{var}_step_{step}": metric_tensor[ - step - 1, var_i - ] # 1-indexed in data_config - for step in timesteps - } - ) + var_name = var_names[var_i] + for step in timesteps: + key = f"{full_log_name}_{var_name}_step_{step}" + log_dict[key] = metric_tensor[step - 1, var_i] return log_dict @@ -536,8 +719,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) metric_name = metric_name.replace("mse", "rmse") - # Note: we here assume rescaling for all metrics is linear - metric_rescaled = metric_tensor_averaged * self.std + # NOTE: we here assume rescaling for all metrics is linear + metric_rescaled = metric_tensor_averaged * self.state_std # (pred_steps, d_f) log_dict.update( self.create_metric_log_dict( @@ -551,8 +734,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix): def on_test_epoch_end(self): """ - Compute test metrics and make plots at the end of test epoch. - Will gather stored tensors and perform plotting and logging on rank 0. + Compute test metrics and make plots at the end of test epoch. Will + gather stored tensors and perform plotting and logging on rank 0. """ # Create error maps for all test metrics self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") @@ -560,21 +743,21 @@ def on_test_epoch_end(self): # Plot spatial loss maps spatial_loss_tensor = self.all_gather_cat( torch.cat(self.spatial_loss_maps, dim=0) - ) # (N_test, N_log, num_grid_nodes) + ) # (N_test, N_log, num_interior_nodes) if self.trainer.is_global_zero: mean_spatial_loss = torch.mean( spatial_loss_tensor, dim=0 - ) # (N_log, num_grid_nodes) + ) # (N_log, num_interior_nodes) loss_map_figs = [ vis.plot_spatial_error( - loss_map, - self.interior_mask[:, 0], - self.config_loader, - title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", + error=loss_map, + datastore=self._datastore, + title=f"Test loss, t={t_i} " + f"({self._datastore.step_length * t_i} h)", ) for t_i, loss_map in zip( - self.args.val_steps_log, mean_spatial_loss + self.args.val_steps_to_log, mean_spatial_loss ) ] @@ -585,13 +768,13 @@ def on_test_epoch_end(self): # also make without title and save as pdf pdf_loss_map_figs = [ vis.plot_spatial_error( - loss_map, self.interior_mask[:, 0], self.config_loader + error=loss_map, datastore=self._datastore ) for loss_map in mean_spatial_loss ] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs): + for t_i, fig in zip(self.args.val_steps_to_log, pdf_loss_map_figs): fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) # save mean spatial loss as .pt file also torch.save( @@ -622,3 +805,6 @@ def on_load_checkpoint(self, checkpoint): ) loaded_state_dict[new_key] = loaded_state_dict[old_key] del loaded_state_dict[old_key] + if not self.restore_opt: + opt = self.configure_optimizers() + checkpoint["optimizer_states"] = [opt.state_dict()] diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index fb5df62d..0e004935 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -1,10 +1,15 @@ +# Standard library +from typing import Union + # Third-party import torch -# First-party -from neural_lam import utils -from neural_lam.interaction_net import InteractionNet -from neural_lam.models.ar_model import ARModel +# Local +from .. import utils +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..interaction_net import InteractionNet +from .ar_model import ARModel class BaseGraphModel(ARModel): @@ -13,14 +18,33 @@ class BaseGraphModel(ARModel): the encode-process-decode idea. """ - def __init__(self, args): - super().__init__(args) + def __init__( + self, + args, + config: NeuralLAMConfig, + datastore: BaseDatastore, + datastore_boundary: Union[BaseDatastore, None], + ): + super().__init__( + args, + config=config, + datastore=datastore, + datastore_boundary=datastore_boundary, + ) # Load graph with static features - # NOTE: (IMPORTANT!) mesh nodes MUST have the first - # num_mesh_nodes indices, - self.hierarchical, graph_ldict = utils.load_graph(args.graph) + graph_dir_path = datastore.root_path / "graphs" / args.graph_name + self.hierarchical, graph_ldict = utils.load_graph( + graph_dir_path=graph_dir_path + ) for name, attr_value in graph_ldict.items(): + # NOTE: It would be good to rescale mesh node position features in + # exactly the same way as grid node position static features. + if name == "mesh_static_features": + max_coord = datastore.get_xy("state").max() + # Rescale by dividing by maximum coordinate in interior + attr_value /= max_coord + # Make BufferLists module members and register tensors as buffers if isinstance(attr_value, torch.Tensor): self.register_buffer(name, attr_value, persistent=False) @@ -28,22 +52,40 @@ def __init__(self, args): setattr(self, name, attr_value) # Specify dimensions of data - self.num_mesh_nodes, _ = self.get_num_mesh() print( - f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} " - f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)" + "Loaded graph with " + f"{self.num_total_grid_nodes + self.num_mesh_nodes} " + f"nodes ({self.num_total_grid_nodes} grid, " + f"{self.num_mesh_nodes} mesh)" ) - # grid_dim from data + static + # interior_dim from data + static self.g2m_edges, g2m_dim = self.g2m_features.shape self.m2g_edges, m2g_dim = self.m2g_features.shape # Define sub-models - # Feature embedders for grid + # Feature embedders for interior self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) - self.grid_embedder = utils.make_mlp( - [self.grid_dim] + self.mlp_blueprint_end + self.interior_embedder = utils.make_mlp( + [self.interior_dim] + self.mlp_blueprint_end ) + + if self.boundary_forced: + # Define embedder for boundary nodes + # Optional separate embedder for boundary nodes + if args.shared_grid_embedder: + assert self.interior_dim == self.boundary_dim, ( + "Grid and boundary input dimension must " + "be the same when using " + f"the same embedder, got interior_dim={self.interior_dim}, " + f"boundary_dim={self.boundary_dim}" + ) + self.boundary_embedder = self.interior_embedder + else: + self.boundary_embedder = utils.make_mlp( + [self.boundary_dim] + self.mlp_blueprint_end + ) + self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) @@ -54,6 +96,7 @@ def __init__(self, args): args.hidden_dim, hidden_layers=args.hidden_layers, update_edges=False, + num_rec=self.num_grid_connected_mesh_nodes, ) self.encoding_grid_mlp = utils.make_mlp( [args.hidden_dim] + self.mlp_blueprint_end @@ -65,6 +108,7 @@ def __init__(self, args): args.hidden_dim, hidden_layers=args.hidden_layers, update_edges=False, + num_rec=self.num_interior_nodes, ) # Output mapping (hidden_dim -> output_dim) @@ -74,12 +118,227 @@ def __init__(self, args): layer_norm=False, ) # No layer norm on this one - def get_num_mesh(self): + # Compute constants for use in time_delta encoding + step_length_ratio = ( + datastore_boundary.step_length / datastore.step_length + ) + min_time_delta = -(args.num_past_boundary_steps + 1) * step_length_ratio + max_time_delta = args.num_future_boundary_steps * step_length_ratio + time_delta_magnitude = max(max_time_delta, abs(min_time_delta)) + + freq_indices = 1.0 + torch.arange( + self.time_delta_enc_dim // 2, + dtype=torch.float, + ) + self.register_buffer( + "enc_freq_denom", + (2 * time_delta_magnitude) + ** (2 * freq_indices / self.time_delta_enc_dim), + persistent=False, + ) + + # Compute indices and define clamping functions + self.prepare_clamping_params(config, datastore) + + @property + def num_mesh_nodes(self): + """ + Get the total number of mesh nodes in the used mesh graph + """ + raise NotImplementedError("num_mesh_nodes not implemented") + + def prepare_clamping_params( + self, config: NeuralLAMConfig, datastore: BaseDatastore + ): + """ + Prepare parameters for clamping predicted values to valid range + """ + + # Read configs + state_feature_names = datastore.get_vars_names(category="state") + lower_lims = config.training.output_clamping.lower + upper_lims = config.training.output_clamping.upper + + # Check that limits in config are for valid features + unknown_features_lower = set(lower_lims.keys()) - set( + state_feature_names + ) + unknown_features_upper = set(upper_lims.keys()) - set( + state_feature_names + ) + if unknown_features_lower or unknown_features_upper: + raise ValueError( + "State feature limits were provided for unknown features: " + f"{unknown_features_lower.union(unknown_features_upper)}" + ) + + # Constant parameters for clamping + sigmoid_sharpness = 1 + softplus_sharpness = 1 + sigmoid_center = 0 + softplus_center = 0 + + normalize_clamping_lim = ( + lambda x, feature_idx: (x - self.state_mean[feature_idx]) + / self.state_std[feature_idx] + ) + + # Check which clamping functions to use for each feature + sigmoid_lower_upper_idx = [] + sigmoid_lower_lims = [] + sigmoid_upper_lims = [] + + softplus_lower_idx = [] + softplus_lower_lims = [] + + softplus_upper_idx = [] + softplus_upper_lims = [] + + for feature_idx, feature in enumerate(state_feature_names): + if feature in lower_lims and feature in upper_lims: + assert ( + lower_lims[feature] < upper_lims[feature] + ), f'Invalid clamping limits for feature "{feature}",\ + lower: {lower_lims[feature]}, larger than\ + upper: {upper_lims[feature]}' + sigmoid_lower_upper_idx.append(feature_idx) + sigmoid_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) + sigmoid_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + elif feature in lower_lims and feature not in upper_lims: + softplus_lower_idx.append(feature_idx) + softplus_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) + elif feature not in lower_lims and feature in upper_lims: + softplus_upper_idx.append(feature_idx) + softplus_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + + self.register_buffer( + "sigmoid_lower_lims", torch.tensor(sigmoid_lower_lims) + ) + self.register_buffer( + "sigmoid_upper_lims", torch.tensor(sigmoid_upper_lims) + ) + self.register_buffer( + "softplus_lower_lims", torch.tensor(softplus_lower_lims) + ) + self.register_buffer( + "softplus_upper_lims", torch.tensor(softplus_upper_lims) + ) + + self.register_buffer( + "clamp_lower_upper_idx", torch.tensor(sigmoid_lower_upper_idx) + ) + self.register_buffer( + "clamp_lower_idx", torch.tensor(softplus_lower_idx) + ) + self.register_buffer( + "clamp_upper_idx", torch.tensor(softplus_upper_idx) + ) + + # Define clamping functions + self.clamp_lower_upper = lambda x: ( + self.sigmoid_lower_lims + + (self.sigmoid_upper_lims - self.sigmoid_lower_lims) + * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) + ) + self.clamp_lower = lambda x: ( + self.softplus_lower_lims + + torch.nn.functional.softplus( + x - softplus_center, beta=softplus_sharpness + ) + ) + self.clamp_upper = lambda x: ( + self.softplus_upper_lims + - torch.nn.functional.softplus( + softplus_center - x, beta=softplus_sharpness + ) + ) + + self.inverse_clamp_lower_upper = lambda x: ( + sigmoid_center + + utils.inverse_sigmoid( + (x - self.sigmoid_lower_lims) + / (self.sigmoid_upper_lims - self.sigmoid_lower_lims) + ) + / sigmoid_sharpness + ) + self.inverse_clamp_lower = lambda x: ( + utils.inverse_softplus( + x - self.softplus_lower_lims, beta=softplus_sharpness + ) + + softplus_center + ) + self.inverse_clamp_upper = lambda x: ( + -utils.inverse_softplus( + self.softplus_upper_lims - x, beta=softplus_sharpness + ) + + softplus_center + ) + + def get_clamped_new_state(self, state_delta, prev_state): """ - Compute number of mesh nodes from loaded features, - and number of mesh nodes that should be ignored in encoding/decoding + Clamp prediction to valid range supplied in config + Returns the clamped new state after adding delta to original state + + Instead of the new state being computed as + $X_{t+1} = X_t + \\delta = X_t + model(\\{X_t,X_{t-1},...\\}, forcing)$ + The clamped values will be + $f(f^{-1}(X_t) + model(\\{X_t, X_{t-1},... \\}, forcing))$ + Which means the model will learn to output values in the range of the + inverse clamping function + + state_delta: (B, num_grid_nodes, feature_dim) + prev_state: (B, num_grid_nodes, feature_dim) """ - raise NotImplementedError("get_num_mesh not implemented") + + # Assign new state, but overwrite clamped values of each type later + new_state = prev_state + state_delta + + # Sigmoid/logistic clamps between ]a,b[ + if self.clamp_lower_upper_idx.numel() > 0: + idx = self.clamp_lower_upper_idx + + new_state[:, :, idx] = self.clamp_lower_upper( + self.inverse_clamp_lower_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + # Softplus clamps between ]a,infty[ + if self.clamp_lower_idx.numel() > 0: + idx = self.clamp_lower_idx + + new_state[:, :, idx] = self.clamp_lower( + self.inverse_clamp_lower(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + # Softplus clamps between ]-infty,b[ + if self.clamp_upper_idx.numel() > 0: + idx = self.clamp_upper_idx + + new_state[:, :, idx] = self.clamp_upper( + self.inverse_clamp_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] + ) + + return new_state + + @property + def num_grid_connected_mesh_nodes(self): + """ + Get the total number of mesh nodes that have a connection to + the grid (e.g. bottom level in a hierarchy) + """ + raise NotImplementedError( + "num_grid_connected_mesh_nodes not implemented" + ) def embedd_mesh_nodes(self): """ @@ -98,46 +357,80 @@ def process_step(self, mesh_rep): """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state, forcing): + def predict_step( + self, prev_state, prev_prev_state, forcing, boundary_forcing + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, num_grid_nodes, feature_dim), X_t - prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} - forcing: (B, num_grid_nodes, forcing_dim) + prev_state: (B, num_interior_nodes, feature_dim), X_t + prev_prev_state: (B, num_interior_nodes, feature_dim), X_{t-1} + forcing: (B, num_interior_nodes, forcing_dim) + boundary_forcing: (B, num_boundary_nodes, boundary_forcing_dim) """ batch_size = prev_state.shape[0] - # Create full grid node features of shape (B, num_grid_nodes, grid_dim) - grid_features = torch.cat( + # Create full interior node features of shape + # (B, num_interior_nodes, interior_dim) + interior_features = torch.cat( ( prev_state, prev_prev_state, forcing, - self.expand_to_batch(self.grid_static_features, batch_size), + self.expand_to_batch(self.interior_static_features, batch_size), ), dim=-1, ) + if self.boundary_forced: + # sin-encode time deltas for boundary forcing + boundary_forcing = self.encode_forcing_time_deltas(boundary_forcing) + + # Create full boundary node features of shape + # (B, num_boundary_nodes, boundary_dim) + boundary_features = torch.cat( + ( + boundary_forcing, + self.expand_to_batch( + self.boundary_static_features, batch_size + ), + ), + dim=-1, + ) + + # Embed boundary features + boundary_emb = self.boundary_embedder(boundary_features) + # (B, num_boundary_nodes, d_h) + # Embed all features - grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) + interior_emb = self.interior_embedder( + interior_features + ) # (B, num_interior_nodes, d_h) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() + if self.boundary_forced: + # Merge interior and boundary emb into input embedding + # We enforce ordering (interior, boundary) of nodes + full_grid_emb = torch.cat((interior_emb, boundary_emb), dim=1) + else: + # Only maps from interior to mesh + full_grid_emb = interior_emb + # Map from grid to mesh mesh_emb_expanded = self.expand_to_batch( mesh_emb, batch_size ) # (B, num_mesh_nodes, d_h) g2m_emb_expanded = self.expand_to_batch(g2m_emb, batch_size) - # This also splits representation into grid and mesh + # Encode to mesh mesh_rep = self.g2m_gnn( - grid_emb, mesh_emb_expanded, g2m_emb_expanded + full_grid_emb, mesh_emb_expanded, g2m_emb_expanded ) # (B, num_mesh_nodes, d_h) # Also MLP with residual for grid representation - grid_rep = grid_emb + self.encoding_grid_mlp( - grid_emb - ) # (B, num_grid_nodes, d_h) + grid_rep = interior_emb + self.encoding_grid_mlp( + interior_emb + ) # (B, num_interior_nodes, d_h) # Run processor step mesh_rep = self.process_step(mesh_rep) @@ -146,18 +439,18 @@ def predict_step(self, prev_state, prev_prev_state, forcing): m2g_emb_expanded = self.expand_to_batch(m2g_emb, batch_size) grid_rep = self.m2g_gnn( mesh_rep, grid_rep, m2g_emb_expanded - ) # (B, num_grid_nodes, d_h) + ) # (B, num_interior_nodes, d_h) # Map to output dimension, only for grid net_output = self.output_map( grid_rep - ) # (B, num_grid_nodes, d_grid_out) + ) # (B, num_interior_nodes, d_grid_out) if self.output_std: pred_delta_mean, pred_std_raw = net_output.chunk( 2, dim=-1 - ) # both (B, num_grid_nodes, d_f) - # Note: The predicted std. is not scaled in any way here + ) # both (B, num_interior_nodes, d_f) + # NOTE: The predicted std. is not scaled in any way here # linter for some reason does not think softplus is callable # pylint: disable-next=not-callable pred_std = torch.nn.functional.softplus(pred_std_raw) @@ -168,5 +461,57 @@ def predict_step(self, prev_state, prev_prev_state, forcing): # Rescale with one-step difference statistics rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean - # Residual connection for full state - return prev_state + rescaled_delta_mean, pred_std + # Clamp values to valid range (also add the delta to the previous state) + new_state = self.get_clamped_new_state(rescaled_delta_mean, prev_state) + + return new_state, pred_std + + def encode_forcing_time_deltas(self, boundary_forcing): + """ + Build sinusoidal encodings of time deltas in boundary forcing. Removes + original time delta features and replaces these with encoded sinusoidal + features, returning the full new forcing tensor. + + Parameters + ---------- + boundary_forcing : torch.Tensor + Tensor of shape (B, num_nodes, num_forcing_dims) containing boundary + forcing features. Time delta features are the last + self.boundary_time_delta_dims dimensions of the num_forcing_dims + feature dimensions. + + + Returns + ------- + encoded_forcing : torch.Tensor + Tensor of shape (B, num_nodes, num_forcing_dims'), where the + time delta features have been removed and encoded versions added. + Note that this might change the number of feature dimensions. + """ + # Extract time delta dimensions + time_deltas = boundary_forcing[..., -self.boundary_time_delta_dims :] + # (B, num_boundary_nodes, num_time_deltas) + + # Compute sinusoidal encodings + frequencies = time_deltas.unsqueeze(-1) / self.enc_freq_denom + # (B, num_boundary_nodes, num_time_deltas, num_freq) + encodings_stacked = torch.cat( + ( + torch.sin(frequencies), + torch.cos(frequencies), + ), + dim=-1, + ) + # (B, num_boundary_nodes, num_time_deltas, 2*num_freq) + + encoded_time_deltas = encodings_stacked.flatten(-2, -1) + # (B, num_boundary_nodes, num_encoding_dims) + + # Put together encoded time deltas with rest of boundary_forcing + return torch.cat( + ( + boundary_forcing[..., : -self.boundary_time_delta_dims], + encoded_time_deltas, + ), + dim=-1, + ) diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index 8ce87030..59cdcc8e 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -1,10 +1,15 @@ +# Standard library +from typing import Union + # Third-party from torch import nn -# First-party -from neural_lam import utils -from neural_lam.interaction_net import InteractionNet -from neural_lam.models.base_graph_model import BaseGraphModel +# Local +from .. import utils +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..interaction_net import InteractionNet +from .base_graph_model import BaseGraphModel class BaseHiGraphModel(BaseGraphModel): @@ -12,8 +17,19 @@ class BaseHiGraphModel(BaseGraphModel): Base class for hierarchical graph models. """ - def __init__(self, args): - super().__init__(args) + def __init__( + self, + args, + config: NeuralLAMConfig, + datastore: BaseDatastore, + datastore_boundary: Union[BaseDatastore, None], + ): + super().__init__( + args, + config=config, + datastore=datastore, + datastore_boundary=datastore_boundary, + ) # Track number of nodes, edges on each level # Flatten lists for efficient embedding @@ -36,7 +52,7 @@ def __init__(self, args): if level_index < (self.num_levels - 1): up_edges = self.mesh_up_features[level_index].shape[0] down_edges = self.mesh_down_features[level_index].shape[0] - print(f" {level_index}<->{level_index+1}") + print(f" {level_index}<->{level_index + 1}") print(f" - {up_edges} up edges, {down_edges} down edges") # Embedders # Assume all levels have same static feature dimensionality @@ -97,18 +113,23 @@ def __init__(self, args): ] ) - def get_num_mesh(self): + @property + def num_mesh_nodes(self): """ - Compute number of mesh nodes from loaded features, - and number of mesh nodes that should be ignored in encoding/decoding + Get the total number of mesh nodes in the used mesh graph """ num_mesh_nodes = sum( node_feat.shape[0] for node_feat in self.mesh_static_features ) - num_mesh_nodes_ignore = ( - num_mesh_nodes - self.mesh_static_features[0].shape[0] - ) - return num_mesh_nodes, num_mesh_nodes_ignore + return num_mesh_nodes + + @property + def num_grid_connected_mesh_nodes(self): + """ + Get the total number of mesh nodes that have a connection to + the grid (e.g. bottom level in a hierarchy) + """ + return self.mesh_static_features[0].shape[0] # Bottom level def embedd_mesh_nodes(self): """ @@ -179,9 +200,9 @@ def process_step(self, mesh_rep): ) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = ( - new_node_rep # (B, num_mesh_nodes[l], d_h) - ) + mesh_rep_levels[ + level_l + ] = new_node_rep # (B, num_mesh_nodes[l], d_h) mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h) # - PROCESSOR - @@ -207,9 +228,9 @@ def process_step(self, mesh_rep): new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = ( - new_node_rep # (B, num_mesh_nodes[l], d_h) - ) + mesh_rep_levels[ + level_l + ] = new_node_rep # (B, num_mesh_nodes[l], d_h) # Return only bottom level representation return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h) diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index f767fba0..bd2b4b2e 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -1,10 +1,15 @@ +# Standard library +from typing import Union + # Third-party import torch_geometric as pyg -# First-party -from neural_lam import utils -from neural_lam.interaction_net import InteractionNet -from neural_lam.models.base_graph_model import BaseGraphModel +# Local +from .. import utils +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..interaction_net import InteractionNet +from .base_graph_model import BaseGraphModel class GraphLAM(BaseGraphModel): @@ -15,14 +20,24 @@ class GraphLAM(BaseGraphModel): Oskarsson et al. (2023). """ - def __init__(self, args): - super().__init__(args) + def __init__( + self, + args, + config: NeuralLAMConfig, + datastore: BaseDatastore, + datastore_boundary: Union[BaseDatastore, None], + ): + super().__init__( + args, + config=config, + datastore=datastore, + datastore_boundary=datastore_boundary, + ) assert ( not self.hierarchical ), "GraphLAM does not use a hierarchical mesh graph" - # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] m2m_edges, m2m_dim = self.m2m_features.shape print( @@ -54,12 +69,20 @@ def __init__(self, args): ], ) - def get_num_mesh(self): + @property + def num_mesh_nodes(self): + """ + Get the total number of mesh nodes in the used mesh graph + """ + return self.mesh_static_features.shape[0] + + @property + def num_grid_connected_mesh_nodes(self): """ - Compute number of mesh nodes from loaded features, - and number of mesh nodes that should be ignored in encoding/decoding + Get the total number of mesh nodes that have a connection to + the grid (e.g. bottom level in a hierarchy) """ - return self.mesh_static_features.shape[0], 0 + return self.num_mesh_nodes # All nodes def embedd_mesh_nodes(self): """ diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py index 4d7eb94c..8ab420e8 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/hi_lam.py @@ -1,9 +1,14 @@ +# Standard library +from typing import Union + # Third-party from torch import nn -# First-party -from neural_lam.interaction_net import InteractionNet -from neural_lam.models.base_hi_graph_model import BaseHiGraphModel +# Local +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..interaction_net import InteractionNet +from .base_hi_graph_model import BaseHiGraphModel class HiLAM(BaseHiGraphModel): @@ -13,8 +18,19 @@ class HiLAM(BaseHiGraphModel): The Hi-LAM model from Oskarsson et al. (2023) """ - def __init__(self, args): - super().__init__(args) + def __init__( + self, + args, + config: NeuralLAMConfig, + datastore: BaseDatastore, + datastore_boundary: Union[BaseDatastore, None], + ): + super().__init__( + args, + config=config, + datastore=datastore, + datastore_boundary=datastore_boundary, + ) # Make down GNNs, both for down edges and same level self.mesh_down_gnns = nn.ModuleList( @@ -200,5 +216,6 @@ def hi_processor_step( up_same_gnns, ) - # Note: We return all, even though only down edges really are used later + # NOTE: We return all, even though only down edges really are used + # later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index 740824e1..c0be48e9 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -1,10 +1,15 @@ +# Standard library +from typing import Union + # Third-party import torch import torch_geometric as pyg -# First-party -from neural_lam.interaction_net import InteractionNet -from neural_lam.models.base_hi_graph_model import BaseHiGraphModel +# Local +from ..config import NeuralLAMConfig +from ..datastore import BaseDatastore +from ..interaction_net import InteractionNet +from .base_hi_graph_model import BaseHiGraphModel class HiLAMParallel(BaseHiGraphModel): @@ -16,8 +21,19 @@ class HiLAMParallel(BaseHiGraphModel): of Hi-LAM. """ - def __init__(self, args): - super().__init__(args) + def __init__( + self, + args, + config: NeuralLAMConfig, + datastore: BaseDatastore, + datastore_boundary: Union[BaseDatastore, None], + ): + super().__init__( + args, + config=config, + datastore=datastore, + datastore_boundary=datastore_boundary, + ) # Processor GNNs # Create the complete edge_index combining all edges for processing @@ -92,5 +108,6 @@ def hi_processor_step( self.num_levels + (self.num_levels - 1) : ] # Last are down edges - # Note: We return all, even though only down edges really are used later + # TODO: We return all, even though only down edges really are used + # later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/plot_graph.py b/neural_lam/plot_graph.py new file mode 100644 index 00000000..9eac68ad --- /dev/null +++ b/neural_lam/plot_graph.py @@ -0,0 +1,282 @@ +# Standard library +import os +from argparse import ArgumentParser + +# Third-party +import numpy as np +import plotly.graph_objects as go +import torch_geometric as pyg + +# Local +from . import utils +from .config import load_config_and_datastores + + +def main(): + """Plot graph structure in 3D using plotly.""" + parser = ArgumentParser(description="Plot graph") + parser.add_argument( + "--config_path", + type=str, + help="Path to the configuration for neural-lam", + ) + parser.add_argument( + "--graph_name", + type=str, + default="multiscale", + help="Name of saved graph to plot (default: multiscale)", + ) + parser.add_argument( + "--save", + type=str, + help="Name of .html file to save interactive plot to (default: None)", + ) + parser.add_argument( + "--show_axis", + action="store_true", + help="If the axis should be displayed (default: False)", + ) + + args = parser.parse_args() + + assert ( + args.config_path is not None + ), "Specify your config with --config_path" + + _, datastore, datastore_boundary = load_config_and_datastores( + config_path=args.config_path + ) + + # Load graph data + graph_dir_path = os.path.join( + datastore.root_path, "graphs", args.graph_name + ) + hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path) + (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( + graph_ldict["g2m_edge_index"], + graph_ldict["m2g_edge_index"], + graph_ldict["m2m_edge_index"], + ) + mesh_up_edge_index, mesh_down_edge_index = ( + graph_ldict["mesh_up_edge_index"], + graph_ldict["mesh_down_edge_index"], + ) + mesh_static_features = graph_ldict["mesh_static_features"] + + # Extract values needed, turn to numpy + # Now plotting is in the 2d CRS of datastore + grid_pos = utils.get_stacked_xy(datastore, datastore_boundary) + # (num_nodes_full, 2) + grid_scale = np.ptp(grid_pos) + + # Add in z-dimension for grid + z_grid = np.zeros((grid_pos.shape[0],)) # Grid sits at z=0 + grid_pos = np.concatenate( + (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 + ) + + # Compute z-coordinate height of mesh nodes + mesh_base_height = 0.05 * grid_scale + mesh_level_height_diff = 0.1 * grid_scale + + # List of edges to plot, (edge_index, from_pos, to_pos, color, + # line_width, label) + edge_plot_list = [] + + # Mesh positioning and edges to plot differ if we have a hierarchical graph + if hierarchical: + mesh_level_pos = [ + np.concatenate( + ( + level_static_features.numpy(), + mesh_base_height + + mesh_level_height_diff + * height_level + * np.ones((level_static_features.shape[0], 1)), + ), + axis=1, + ) + for height_level, level_static_features in enumerate( + mesh_static_features, start=1 + ) + ] + all_mesh_pos = np.concatenate(mesh_level_pos, axis=0) + grid_con_mesh_pos = mesh_level_pos[0] + + # Add inter-level mesh edges + edge_plot_list += [ + ( + level_ei.numpy(), + level_pos, + level_pos, + "blue", + 1, + f"M2M Level {level}", + ) + for level, (level_ei, level_pos) in enumerate( + zip(m2m_edge_index, mesh_level_pos) + ) + ] + + # Add intra-level mesh edges + up_edges_ei = [ + level_up_ei.numpy() for level_up_ei in mesh_up_edge_index + ] + down_edges_ei = [ + level_down_ei.numpy() for level_down_ei in mesh_down_edge_index + ] + # Add up edges + for level_i, (up_ei, from_pos, to_pos) in enumerate( + zip(up_edges_ei, mesh_level_pos[:-1], mesh_level_pos[1:]) + ): + edge_plot_list.append( + ( + up_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh up {level_i}-{level_i + 1}", + ) + ) + # Add down edges + for level_i, (down_ei, from_pos, to_pos) in enumerate( + zip(down_edges_ei, mesh_level_pos[1:], mesh_level_pos[:-1]) + ): + edge_plot_list.append( + ( + down_ei, + from_pos, + to_pos, + "green", + 1, + f"Mesh down {level_i + 1}-{level_i}", + ) + ) + + edge_plot_list.append( + ( + m2g_edge_index.numpy(), + grid_con_mesh_pos, + grid_pos, + "black", + 0.4, + "M2G", + ) + ) + edge_plot_list.append( + ( + g2m_edge_index.numpy(), + grid_pos, + grid_con_mesh_pos, + "black", + 0.4, + "G2M", + ) + ) + + mesh_node_size = 2.5 + else: + mesh_pos = mesh_static_features.numpy() + + mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy() + # 1% higher per neighbor + z_mesh = (1 + 0.01 * mesh_degrees) * mesh_base_height + mesh_node_size = mesh_degrees / 2 + + mesh_pos = np.concatenate( + (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 + ) + + edge_plot_list.append( + (m2m_edge_index.numpy(), mesh_pos, mesh_pos, "blue", 1, "M2M") + ) + edge_plot_list.append( + (m2g_edge_index.numpy(), mesh_pos, grid_pos, "black", 0.4, "M2G") + ) + edge_plot_list.append( + (g2m_edge_index.numpy(), grid_pos, mesh_pos, "black", 0.4, "G2M") + ) + + all_mesh_pos = mesh_pos + + # Add edges + data_objs = [] + for ( + ei, + from_pos, + to_pos, + col, + width, + label, + ) in edge_plot_list: + edge_start = from_pos[ei[0]] # (M, 2) + edge_end = to_pos[ei[1]] # (M, 2) + n_edges = edge_start.shape[0] + + x_edges = np.stack( + (edge_start[:, 0], edge_end[:, 0], np.full(n_edges, None)), axis=1 + ).flatten() + y_edges = np.stack( + (edge_start[:, 1], edge_end[:, 1], np.full(n_edges, None)), axis=1 + ).flatten() + z_edges = np.stack( + (edge_start[:, 2], edge_end[:, 2], np.full(n_edges, None)), axis=1 + ).flatten() + + scatter_obj = go.Scatter3d( + x=x_edges, + y=y_edges, + z=z_edges, + mode="lines", + line={"color": col, "width": width}, + name=label, + ) + data_objs.append(scatter_obj) + + # Add node objects + + data_objs.append( + go.Scatter3d( + x=grid_pos[:, 0], + y=grid_pos[:, 1], + z=grid_pos[:, 2], + mode="markers", + marker={"color": "black", "size": 1}, + name="Grid nodes", + ) + ) + data_objs.append( + go.Scatter3d( + x=all_mesh_pos[:, 0], + y=all_mesh_pos[:, 1], + z=all_mesh_pos[:, 2], + mode="markers", + marker={"color": "blue", "size": mesh_node_size}, + name="Mesh nodes", + ) + ) + + fig = go.Figure(data=data_objs) + + fig.update_layout(scene_aspectmode="data") + fig.update_traces(connectgaps=False) + + if not args.show_axis: + # Hide axis + fig.update_layout( + scene={ + "xaxis": {"visible": False}, + "yaxis": {"visible": False}, + "zaxis": {"visible": False}, + } + ) + + if args.save: + fig.write_html(args.save, include_plotlyjs="cdn") + else: + fig.show() + + +if __name__ == "__main__": + main() diff --git a/train_model.py b/neural_lam/train_model.py similarity index 56% rename from train_model.py rename to neural_lam/train_model.py index a8b02f58..2313fda4 100644 --- a/train_model.py +++ b/neural_lam/train_model.py @@ -1,4 +1,5 @@ # Standard library +import json import random import time from argparse import ArgumentParser @@ -6,15 +7,14 @@ # Third-party import pytorch_lightning as pl import torch -import wandb from lightning_fabric.utilities import seed +from loguru import logger -# First-party -from neural_lam import utils -from neural_lam.models.graph_lam import GraphLAM -from neural_lam.models.hi_lam import HiLAM -from neural_lam.models.hi_lam_parallel import HiLAMParallel -from neural_lam.weather_dataset import WeatherDataModule +# Local +from . import utils +from .config import load_config_and_datastores +from .models import GraphLAM, HiLAM, HiLAMParallel +from .weather_dataset import WeatherDataModule MODELS = { "graph_lam": GraphLAM, @@ -23,25 +23,22 @@ } -def main(): - """ - Main function for training and evaluating models - """ +@logger.catch +def main(input_args=None): + """Main function for training and evaluating models.""" parser = ArgumentParser( description="Train or evaluate NeurWP models for LAM" ) - parser.add_argument( - "--model", + "--config_path", type=str, - default="graph_lam", - help="Model architecture to train/evaluate (default: graph_lam)", + help="Path to the configuration for neural-lam", ) parser.add_argument( - "--data_config", + "--model", type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", + default="graph_lam", + help="Model architecture to train/evaluate (default: graph_lam)", ) parser.add_argument( "--seed", type=int, default=42, help="random seed (default: 42)" @@ -52,6 +49,12 @@ def main(): default=4, help="Number of workers in data loader (default: 4)", ) + parser.add_argument( + "--num_nodes", + type=int, + default=1, + help="Number of nodes to use in DDP (default: 1)", + ) parser.add_argument( "--epochs", type=int, @@ -68,10 +71,9 @@ def main(): ) parser.add_argument( "--restore_opt", - type=int, - default=0, + action="store_true", help="If optimizer state should be restored with model " - "(default: 0 (false))", + "(default: false)", ) parser.add_argument( "--precision", @@ -79,14 +81,20 @@ def main(): default=32, help="Numerical precision to use for model (32/16/bf16) (default: 32)", ) + parser.add_argument( + "--num_sanity_steps", + type=int, + default=2, + help="Number of sanity checking validation steps to run before starting" + " training (default: 2)", + ) # Model architecture parser.add_argument( - "--graph", + "--graph_name", type=str, default="multiscale", - help="Graph to load and use in graph-based model " - "(default: multiscale)", + help="Graph to load and use in graph-based model (default: multiscale)", ) parser.add_argument( "--hidden_dim", @@ -115,27 +123,33 @@ def main(): ) parser.add_argument( "--output_std", - type=int, - default=0, + action="store_true", help="If models should additionally output std.-dev. per " "output dimensions " - "(default: 0 (no))", + "(default: False (no))", + ) + parser.add_argument( + "--shared_grid_embedder", + action="store_true", # Default to separate embedders + help="If the same embedder MLP should be used for interior and boundary" + " grid nodes. Note that this requires the same dimensionality for " + "both kinds of grid inputs. (default: False (no))", + ) + parser.add_argument( + "--time_delta_enc_dim", + type=int, + help="Dimensionality of positional encoding for time deltas of boundary" + " forcing. If None, same as hidden_dim. If given, must be even " + "(default: None)", ) # Training options parser.add_argument( "--ar_steps_train", type=int, - default=3, + default=1, help="Number of steps to unroll prediction for in loss function " - "(default: 3)", - ) - parser.add_argument( - "--control_only", - type=int, - default=0, - help="Train only on control member of ensemble data " - "(default: 0 (False))", + "(default: 1)", ) parser.add_argument( "--loss", @@ -143,13 +157,6 @@ def main(): default="wmse", help="Loss function to use, see metric.py (default: wmse)", ) - parser.add_argument( - "--step_length", - type=int, - default=1, - help="Step length in hours to consider single time step 1-3 " - "(default: 1)", - ) parser.add_argument( "--lr", type=float, default=1e-3, help="learning rate (default: 0.001)" ) @@ -160,6 +167,12 @@ def main(): help="Number of epochs training between each validation run " "(default: 1)", ) + parser.add_argument( + "--grad_checkpointing", + action="store_true", + help="If gradient checkpointing should be used in-between each " + "unrolling step (default: false)", + ) # Evaluation options parser.add_argument( @@ -171,9 +184,9 @@ def main(): parser.add_argument( "--ar_steps_eval", type=int, - default=25, - help="Number of steps to unroll prediction for in loss function " - "(default: 25)", + default=10, + help="Number of steps to unroll prediction for during evaluation " + "(default: 10)", ) parser.add_argument( "--n_example_pred", @@ -183,36 +196,79 @@ def main(): "(default: 1)", ) - # Logging Options + # Logger Settings parser.add_argument( "--wandb_project", type=str, - default="neural-lam", - help="Wandb project to log to (default: neural-lam)", + default="neural_lam", + help="Wandb project name (default: neural_lam)", ) parser.add_argument( - "--val_steps_log", - type=list, + "--val_steps_to_log", + nargs="+", + type=int, default=[1, 2, 3, 5, 10, 15, 19], - help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])", + help="Steps to log val loss for (default: 1 2 3 5 10 15 19)", ) parser.add_argument( "--metrics_watch", - type=list, + nargs="+", default=[], help="List of metrics to watch, including any prefix (e.g. val_rmse)", ) parser.add_argument( "--var_leads_metrics_watch", - type=dict, - default={}, - help="Dict with variables and lead times to log watched metrics for", + type=str, + default="{}", + help="""JSON string with variable-IDs and lead times to log watched + metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""", + ) + parser.add_argument( + "--num_past_forcing_steps", + type=int, + default=1, + help="Number of past time steps to use as input for forcing data", + ) + parser.add_argument( + "--num_future_forcing_steps", + type=int, + default=1, + help="Number of future time steps to use as input for forcing data", + ) + parser.add_argument( + "--num_past_boundary_steps", + type=int, + default=1, + help="Number of past time steps to use as input for boundary data", + ) + parser.add_argument( + "--num_future_boundary_steps", + type=int, + default=1, + help="Number of future time steps to use as input for boundary data", + ) + parser.add_argument( + "--interior_subsample_step", + type=int, + default=1, + help="Subsample step for interior grid nodes", + ) + parser.add_argument( + "--boundary_subsample_step", + type=int, + default=1, + help="Subsample step for boundary grid nodes", ) - args = parser.parse_args() + args = parser.parse_args(input_args) + args.var_leads_metrics_watch = { + int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() + } # Asserts for arguments + assert ( + args.config_path is not None + ), "Specify your config with --config_path" assert args.model in MODELS, f"Unknown model: {args.model}" - assert args.step_length <= 3, "Too high step length" assert args.eval in ( None, "val", @@ -224,10 +280,25 @@ def main(): # Set seed seed.seed_everything(args.seed) + + # Load neural-lam configuration and datastore to use + config, datastore, datastore_boundary = load_config_and_datastores( + config_path=args.config_path + ) + # Create datamodule data_module = WeatherDataModule( + datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=args.ar_steps_train, ar_steps_eval=args.ar_steps_eval, + standardize=True, + num_past_forcing_steps=args.num_past_forcing_steps, + num_future_forcing_steps=args.num_future_forcing_steps, + num_past_boundary_steps=args.num_past_boundary_steps, + num_future_boundary_steps=args.num_future_boundary_steps, + interior_subsample_step=args.interior_subsample_step, + boundary_subsample_step=args.boundary_subsample_step, batch_size=args.batch_size, num_workers=args.num_workers, ) @@ -242,15 +313,13 @@ def main(): device_name = "cpu" # Load model parameters Use new args for model - model_class = MODELS[args.model] - if args.load: - model = model_class.load_from_checkpoint(args.load, args=args) - if args.restore_opt: - # Save for later - # Unclear if this works for multi-GPU - model.opt_state = torch.load(args.load)["optimizer_states"][0] - else: - model = model_class(args) + ModelClass = MODELS[args.model] + model = ModelClass( + args, + config=config, + datastore=datastore, + datastore_boundary=datastore_boundary, + ) if args.eval: prefix = f"eval-{args.eval}-" @@ -268,30 +337,33 @@ def main(): save_last=True, ) logger = pl.loggers.WandbLogger( - project=args.wandb_project, name=run_name, config=args + project=args.wandb_project, + name=run_name, + config=dict(training=vars(args), datastore=datastore._config), ) trainer = pl.Trainer( max_epochs=args.epochs, deterministic=True, strategy="ddp", accelerator=device_name, + num_nodes=args.num_nodes, logger=logger, log_every_n_steps=1, callbacks=[checkpoint_callback], check_val_every_n_epoch=args.val_interval, precision=args.precision, + num_sanity_val_steps=args.num_sanity_steps, ) # Only init once, on rank 0 only if trainer.global_rank == 0: utils.init_wandb_metrics( - logger, val_steps=args.val_steps_log + logger, val_steps=args.val_steps_to_log ) # Do after wandb.init - wandb.save(args.data_config) if args.eval: trainer.test(model=model, datamodule=data_module, ckpt_path=args.load) else: - trainer.fit(model=model, datamodule=data_module) + trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load) if __name__ == "__main__": diff --git a/neural_lam/utils.py b/neural_lam/utils.py index ed5656f3..ff62b1a2 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,5 +1,6 @@ # Standard library import os +import shutil # Third-party import cartopy.crs as ccrs @@ -35,29 +36,142 @@ def __len__(self): def __iter__(self): return (self[i] for i in range(len(self))) + def __itruediv__(self, other): + """Divide each element in list with other""" + return self.__imul__(1.0 / other) -def load_graph(graph_name, device="cpu"): + def __imul__(self, other): + """Multiply each element in list with other""" + for buffer_tensor in self: + buffer_tensor *= other + + return self + + +def zero_index_edge_index(edge_index): """ + Make both sender and receiver indices of edge_index start at 0 + """ + return edge_index - edge_index.min(dim=1, keepdim=True)[0] + + +def load_graph(graph_dir_path, device="cpu"): + """Load all tensors representing the graph from `graph_dir_path`. + + Needs the following files for all graphs: + - m2m_edge_index.pt + - g2m_edge_index.pt + - m2g_edge_index.pt + - m2m_features.pt + - g2m_features.pt + - m2g_features.pt + - m2m_node_features.pt + + And in addition for hierarchical graphs: + - mesh_up_edge_index.pt + - mesh_down_edge_index.pt + - mesh_up_features.pt + - mesh_down_features.pt + + Parameters + ---------- + graph_dir_path : str + Path to directory containing the graph files. + device : str + Device to load tensors to. + + Returns + ------- + hierarchical : bool + Whether the graph is hierarchical. + graph : dict + Dictionary containing the graph tensors, with keys as follows: + - g2m_edge_index + - m2g_edge_index + - m2m_edge_index + - mesh_up_edge_index + - mesh_down_edge_index + - g2m_features + - m2g_features + - m2m_node_features + - mesh_up_features + - mesh_down_features + - mesh_static_features + + Load all tensors representing the graph """ - # Define helper lambda function - graph_dir_path = os.path.join("graphs", graph_name) def loads_file(fn): - return torch.load(os.path.join(graph_dir_path, fn), map_location=device) + return torch.load( + os.path.join(graph_dir_path, fn), + map_location=device, + weights_only=True, + ) + + # Load static node features + mesh_static_features = loads_file( + "m2m_node_features.pt" + ) # List of (N_mesh[l], d_mesh_static) # Load edges (edge_index) m2m_edge_index = BufferList( - loads_file("m2m_edge_index.pt"), persistent=False + [zero_index_edge_index(ei) for ei in loads_file("m2m_edge_index.pt")], + persistent=False, ) # List of (2, M_m2m[l]) g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m) m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g) + # Change first indices to 0 + # m2g and g2m has to be handled specially as not all mesh nodes + # might be indexed + m2g_min_indices = m2g_edge_index.min(dim=1, keepdim=True)[0] + if m2g_min_indices[0] < m2g_min_indices[1]: + # mesh has the first indices + # Number of mesh nodes at level that connects to grid + num_mesh_nodes = mesh_static_features[0].shape[0] + + m2g_edge_index = torch.stack( + ( + m2g_edge_index[0], + m2g_edge_index[1] - num_mesh_nodes, + ), + dim=0, + ) + g2m_edge_index = torch.stack( + ( + g2m_edge_index[0] - num_mesh_nodes, + g2m_edge_index[1], + ), + dim=0, + ) + else: + # grid (interior) has the first indices + # NOTE: Below works, but would be good with a better way to get this + num_interior_nodes = m2g_edge_index[1].max() + 1 + num_grid_nodes = g2m_edge_index[0].max() + 1 + + m2g_edge_index = torch.stack( + ( + m2g_edge_index[0] - num_interior_nodes, + m2g_edge_index[1], + ), + dim=0, + ) + g2m_edge_index = torch.stack( + ( + g2m_edge_index[0], + g2m_edge_index[1] - num_grid_nodes, + ), + dim=0, + ) + n_levels = len(m2m_edge_index) hierarchical = n_levels > 1 # Nor just single level mesh graph # Load static edge features - m2m_features = loads_file("m2m_features.pt") # List of (M_m2m[l], d_edge_f) + # List of (M_m2m[l], d_edge_f) + m2m_features = loads_file("m2m_features.pt") g2m_features = loads_file("g2m_features.pt") # (M_g2m, d_edge_f) m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f) @@ -65,18 +179,11 @@ def loads_file(fn): longest_edge = max( torch.max(level_features[:, 0]) for level_features in m2m_features ) # Col. 0 is length - m2m_features = BufferList( - [level_features / longest_edge for level_features in m2m_features], - persistent=False, - ) + m2m_features = BufferList(m2m_features, persistent=False) + m2m_features /= longest_edge g2m_features = g2m_features / longest_edge m2g_features = m2g_features / longest_edge - # Load static node features - mesh_static_features = loads_file( - "mesh_features.pt" - ) # List of (N_mesh[l], d_mesh_static) - # Some checks for consistency assert ( len(m2m_features) == n_levels @@ -88,10 +195,18 @@ def loads_file(fn): if hierarchical: # Load up and down edges and features mesh_up_edge_index = BufferList( - loads_file("mesh_up_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_up_edge_index.pt") + ], + persistent=False, ) # List of (2, M_up[l]) mesh_down_edge_index = BufferList( - loads_file("mesh_down_edge_index.pt"), persistent=False + [ + zero_index_edge_index(ei) + for ei in loads_file("mesh_down_edge_index.pt") + ], + persistent=False, ) # List of (2, M_down[l]) mesh_up_features = loads_file( @@ -102,20 +217,10 @@ def loads_file(fn): ) # List of (M_down[l], d_edge_f) # Rescale - mesh_up_features = BufferList( - [ - edge_features / longest_edge - for edge_features in mesh_up_features - ], - persistent=False, - ) - mesh_down_features = BufferList( - [ - edge_features / longest_edge - for edge_features in mesh_down_features - ], - persistent=False, - ) + mesh_up_features = BufferList(mesh_up_features, persistent=False) + mesh_up_features /= longest_edge + mesh_down_features = BufferList(mesh_down_features, persistent=False) + mesh_down_features /= longest_edge mesh_static_features = BufferList( mesh_static_features, persistent=False @@ -179,7 +284,11 @@ def fractional_plot_bundle(fraction): Get the tueplots bundle, but with figure width as a fraction of the page width. """ - bundle = bundles.neurips2023(usetex=True, family="serif") + # If latex is not available, some visualizations might not render + # correctly, but will at least not raise an error. Alternatively, use + # unicode raised numbers. + usetex = True if shutil.which("latex") else False + bundle = bundles.neurips2023(usetex=usetex, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( @@ -199,194 +308,263 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric(f"val_loss_unroll{step}", summary="min") -class ConfigLoader: +def get_stacked_lat_lons(datastore, datastore_boundary=None): + """ + Stack the lat-lon coordinates of all grid nodes in the correct ordering + + Parameters + ---------- + datastore : BaseDatastore + The datastore containing data for the interior region of the grid + datastore_boundary : BaseDatastore or None + (Optional) The datastore containing data for boundary forcing + + Returns + ------- + stacked_coords : np.ndarray + Array of all coordinates, shaped (num_total_grid_nodes, 2) """ - Class for loading configuration files. + grid_coords = datastore.get_lat_lon(category="state") - This class loads a YAML configuration file and provides a way to access - its values as attributes. + if datastore_boundary is None: + return grid_coords + + # Append boundary forcing positions last + boundary_coords = datastore_boundary.get_lat_lon(category="forcing") + return np.concatenate((grid_coords, boundary_coords), axis=0) + + +def get_stacked_xy(datastore, datastore_boundary=None): """ + Stack the xy coordinates of all grid nodes in the correct ordering, + with xy coordinates being in the CRS of the datastore + + Parameters + ---------- + datastore : BaseDatastore + The datastore containing data for the interior region of the grid + datastore_boundary : BaseDatastore or None + (Optional) The datastore containing data for boundary forcing + + Returns + ------- + stacked_coords : np.ndarray + Array of all coordinates, shaped (num_total_grid_nodes, 2) + """ + lat_lons = get_stacked_lat_lons(datastore, datastore_boundary) - def __init__(self, config_path, values=None): - self.config_path = config_path - if values is None: - self.values = self.load_config() - else: - self.values = values - - def load_config(self): - """Load configuration file.""" - with open(self.config_path, encoding="utf-8", mode="r") as file: - return yaml.safe_load(file) - - def __getattr__(self, name): - keys = name.split(".") - value = self.values - for key in keys: - if key in value: - value = value[key] - else: - return None - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value + # transform to datastore CRS + xyz = datastore.coords_projection.transform_points( + ccrs.PlateCarree(), lat_lons[:, 0], lat_lons[:, 1] + ) + return xyz[:, :2] - def __getitem__(self, key): - value = self.values[key] - if isinstance(value, dict): - return ConfigLoader(None, values=value) - return value - - def __contains__(self, key): - return key in self.values - - def param_names(self): - """Return parameter names.""" - surface_names = self.values["state"]["surface"] - atmosphere_names = [ - f"{var}_{level}" - for var in self.values["state"]["atmosphere"] - for level in self.values["state"]["levels"] - ] - return surface_names + atmosphere_names - - def param_units(self): - """Return parameter units.""" - surface_units = self.values["state"]["surface_units"] - atmosphere_units = [ - unit - for unit in self.values["state"]["atmosphere_units"] - for _ in self.values["state"]["levels"] - ] - return surface_units + atmosphere_units - - def num_data_vars(self, key): - """Return the number of data variables for a given key.""" - surface_vars = len(self.values[key]["surface"]) - atmosphere_vars = len(self.values[key]["atmosphere"]) - levels = len(self.values[key]["levels"]) - return surface_vars + atmosphere_vars * levels - - def projection(self): - """Return the projection.""" - proj_config = self.values["projections"]["class"] - proj_class = getattr(ccrs, proj_config["proj_class"]) - proj_params = proj_config["proj_params"] - return proj_class(**proj_params) - - def open_zarr(self, dataset_name): - """Open a dataset specified by the dataset name.""" - dataset_path = self.zarrs[dataset_name].path - if dataset_path is None or not os.path.exists(dataset_path): - print(f"Dataset '{dataset_name}' not found at path: {dataset_path}") - return None - dataset = xr.open_zarr(dataset_path, consolidated=True) - return dataset - - def load_normalization_stats(self): - """Load normalization statistics from Zarr archive.""" - normalization_path = self.normalization.zarr - if not os.path.exists(normalization_path): - print( - f"Normalization statistics not found at " - f"path: {normalization_path}" - ) - return None - normalization_stats = xr.open_zarr( - normalization_path, consolidated=True + +def get_time_step(times): + """Calculate the time step from a time dataarray. + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + + Returns + ------- + time_step : float + The time step in the the datetime-format of the times dataarray. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" ) - return normalization_stats + return time_diffs[0] + + +def check_time_overlap( + da1, + da2, + da1_is_forecast=False, + da2_is_forecast=False, + num_past_steps=1, + num_future_steps=1, +): + """Check that the time coverage of two dataarrays overlap. + + Parameters + ---------- + da1 : xr.DataArray + The first dataarray to check. + da2 : xr.DataArray + The second dataarray to check. + da1_is_forecast : bool, optional + Whether the first dataarray is forecast data. + da2_is_forecast : bool, optional + Whether the second dataarray is forecast data. + num_past_steps : int, optional + Number of past forcing steps. + num_future_steps : int, optional + Number of future forcing steps. + + Raises + ------ + ValueError + If the time coverage of the dataarrays does not overlap. + """ - def process_dataset(self, dataset_name, split="train", stack=True): - """Process a single dataset specified by the dataset name.""" + if da1_is_forecast: + times_da1 = da1.analysis_time + else: + times_da1 = da1.time + time_min_da1 = times_da1.min().values + time_max_da1 = times_da1.max().values - dataset = self.open_zarr(dataset_name) - if dataset is None: - return None + if da2_is_forecast: + times_da2 = da2.analysis_time + _ = get_time_step(da2.elapsed_forecast_duration) + else: + times_da2 = da2.time + time_step_da2 = get_time_step(times_da2.values) + + time_min_da2 = times_da2.min().values + time_max_da2 = times_da2.max().values - start, end = ( - self.splits[split].start, - self.splits[split].end, + # Calculate required bounds for da2 using its time step + da2_required_time_min = time_min_da1 - num_past_steps * time_step_da2 + da2_required_time_max = time_max_da1 + num_future_steps * time_step_da2 + + if time_min_da2 > da2_required_time_min: + raise ValueError( + f"The second DataArray (e.g. 'boundary forcing') starts too late." + f"Required start: {da2_required_time_min}, " + f"but DataArray starts at {time_min_da2}." ) - dataset = dataset.sel(time=slice(start, end)) - dataset = dataset.rename_dims( - { - v: k - for k, v in self.zarrs[dataset_name].dims.values.items() - if k not in dataset.dims - } + + if time_max_da2 < da2_required_time_max: + raise ValueError( + f"The second DataArray (e.g. 'boundary forcing') ends too early." + f"Required end: {da2_required_time_max}, " + f"but DataArray ends at {time_max_da2}." ) - vars_surface = [] - if self[dataset_name].surface: - vars_surface = dataset[self[dataset_name].surface] - - vars_atmosphere = [] - if self[dataset_name].atmosphere: - vars_atmosphere = xr.merge( - [ - dataset[var] - .sel(level=level, drop=True) - .rename(f"{var}_{level}") - for var in self[dataset_name].atmosphere - for level in self[dataset_name].levels - ] - ) - if vars_surface and vars_atmosphere: - dataset = xr.merge([vars_surface, vars_atmosphere]) - elif vars_surface: - dataset = vars_surface - elif vars_atmosphere: - dataset = vars_atmosphere - else: - print(f"No variables found in dataset {dataset_name}") - return None - - lat_name = self.zarrs[dataset_name].lat_lon_names.lat - lon_name = self.zarrs[dataset_name].lat_lon_names.lon - if dataset[lat_name].ndim == 2: - dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True) - if dataset[lon_name].ndim == 2: - dataset[lon_name] = dataset[lon_name].isel(y=0, drop=True) - - if "x" in dataset.dims: - dataset = dataset.rename({"x": "old_x"}) - if "y" in dataset.dims: - dataset = dataset.rename({"y": "old_y"}) - dataset = dataset.assign_coords( - x=dataset[lon_name], y=dataset[lat_name] +def crop_time_if_needed( + da1, + da2, + da1_is_forecast=False, + da2_is_forecast=False, + num_past_steps=1, + num_future_steps=1, +): + """ + Slice away the first few timesteps from the first DataArray (e.g. 'state') + if the second DataArray (e.g. boundary forcing) does not cover that range + (including num_past_steps). + + Parameters + ---------- + da1 : xr.DataArray + The first DataArray to crop. + da2 : xr.DataArray + The second DataArray to compare against. + da1_is_forecast : bool, optional + Whether the first dataarray is forecast data. + da2_is_forecast : bool, optional + Whether the second dataarray is forecast data. + num_past_steps : int + Number of past time steps to consider. + num_future_steps : int + Number of future time steps to consider. + + Return + ------ + da1 : xr.DataArray + The cropped first DataArray and print a warning if any steps are + removed. + """ + if da1 is None or da2 is None: + return da1 + + try: + check_time_overlap( + da1, + da2, + da1_is_forecast, + da2_is_forecast, + num_past_steps, + num_future_steps, ) - dataset["x"] = dataset[lon_name] - dataset["y"] = dataset[lat_name] + return da1 + except ValueError: + # If da2 coverage is insufficient, remove earliest da1 times + # until coverage is possible. Figure out how many steps to remove. + if da1_is_forecast: + da1_tvals = da1.analysis_time.values + else: + da1_tvals = da1.time.values + if da2_is_forecast: + da2_tvals = da2.analysis_time.values + else: + da2_tvals = da2.time.values + + # Calculate how many steps we would have to remove + if da2_is_forecast: + # The windowing for forecast type data happens in the + # elapsed_forecast_duration dimension, so we can omit it here. + required_min = da2_tvals[0] + required_max = da2_tvals[-1] + else: + dt = get_time_step(da2_tvals) + required_min = da2_tvals[0] + num_past_steps * dt + required_max = da2_tvals[-1] - num_future_steps * dt + + # Calculate how many steps to remove at beginning and end + first_valid_idx = (da1_tvals >= required_min).argmax() + n_removed_begin = first_valid_idx + last_valid_idx_plus_one = ( + da1_tvals > required_max + ).argmax() # To use for slice + n_removed_begin = first_valid_idx + n_removed_end = len(da1_tvals) - last_valid_idx_plus_one + if n_removed_begin > 0 or n_removed_end > 0: + print( + f"Warning: cropping da1 (e.g. 'state') to align with da2 " + f"(e.g. 'boundary forcing'). Removed {n_removed_begin} steps " + f"at start of data interval and {n_removed_end} at the end." + ) + da1 = da1.isel(time=slice(first_valid_idx, last_valid_idx_plus_one)) + return da1 - if stack: - dataset = self.stack_grid(dataset) - return dataset +def inverse_softplus(x, beta=1, threshold=20): + """ + Inverse of torch.nn.functional.softplus - def stack_grid(self, dataset): - """Stack grid dimensions.""" - dataset = dataset.squeeze().stack(grid=("x", "y")).to_array() + For x*beta above threshold, returns linear function for numerical + stability. - if "time" in dataset.dims: - dataset = dataset.transpose("time", "grid", "variable") - else: - dataset = dataset.transpose("grid", "variable") - return dataset - - def get_nwp_xy(self, dataset_name): - """Get the longitude and latitude coordinates for the NWP grid.""" - x = np.sort(self.process_dataset(dataset_name, stack=False).x.values) - y = np.sort(self.process_dataset(dataset_name, stack=False).y.values) - xx, yy = np.meshgrid(x, y) - xy = np.stack((xx.T, yy.T), axis=0) - - return xy - - def get_step_length(self): - """Get the temporal resolution for a given dataset.""" - times = self.open_zarr("state").isel(time=slice(0, 2)).time.values - step_length = times[1] - times[0] - step_length_hours = step_length.astype("timedelta64[h]").astype(int) - return step_length_hours + Input is clamped to x > ln(1+1e-6)/beta which is approximately positive + values of x. + Note that this torch.clamp_min will make gradients 0, but this is not a + problem as values of x that are this close to 0 have gradients of 0 anyhow. + """ + non_linear_part = ( + torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta + ) + x = torch.where(x * beta <= threshold, non_linear_part, x) + + return x + + +def inverse_sigmoid(x): + """ + Inverse of torch.sigmoid + + Sigmoid output takes values in [0,1], this makes sure input is just within + this interval. + Note that this torch.clamp will make gradients 0, but this is not a problem + as values of x that are this close to 0 or 1 have gradients of 0 anyhow. + """ + x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6) + return torch.log(x_clamped / (1 - x_clamped)) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 8c36a9a7..044afa7a 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -2,13 +2,15 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import xarray as xr -# First-party -from neural_lam import utils +# Local +from . import utils +from .datastore.base import BaseRegularGridDatastore @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_error_map(errors, data_config, title=None, step_length=3): +def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): """ Plot a heatmap of errors of different variables at different predictions horizons @@ -16,6 +18,7 @@ def plot_error_map(errors, data_config, title=None, step_length=3): """ errors_np = errors.T.cpu().numpy() # (d_f, pred_steps) d_f, pred_steps = errors_np.shape + step_length = datastore.step_length # Normalize all errors to [0,1] for color map max_errors = errors_np.max(axis=1) # d_f @@ -48,11 +51,10 @@ def plot_error_map(errors, data_config, title=None, step_length=3): ax.set_xlabel("Lead time (h)", size=label_size) ax.set_yticks(np.arange(d_f)) + var_names = datastore.get_vars_names(category="state") + var_units = datastore.get_vars_units(category="state") y_ticklabels = [ - f"{name} ({unit})" - for name, unit in zip( - data_config.param_names(), data_config.param_units() - ) + f"{name} ({unit})" for name, unit in zip(var_names, var_units) ] ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) @@ -62,52 +64,82 @@ def plot_error_map(errors, data_config, title=None, step_length=3): return fig +def plot_on_axis( + ax, + da, + datastore, + obs_mask=None, + vmin=None, + vmax=None, + ax_title=None, + cmap="plasma", + grid_limits=None, +): + """ + Plot weather state on given axis + """ + ax.coastlines() # Add coastline outlines + + extent = datastore.get_xy_extent("state") + + im = da.plot.imshow( + ax=ax, + origin="lower", + x="x", + extent=extent, + vmin=vmin, + vmax=vmax, + cmap=cmap, + transform=datastore.coords_projection, + ) + + if ax_title: + ax.set_title(ax_title, size=15) + + return im + + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, target, obs_mask, data_config, title=None, vrange=None + datastore: BaseRegularGridDatastore, + da_prediction: xr.DataArray = None, + da_target: xr.DataArray = None, + title=None, + vrange=None, ): """ Plot example prediction and grond truth. + Each has shape (N_grid,) + """ # Get common scale for values if vrange is None: - vmin = min(vals.min().cpu().item() for vals in (pred, target)) - vmax = max(vals.max().cpu().item() for vals in (pred, target)) + vmin = min(da_prediction.min(), da_target.min()) + vmax = max(da_prediction.max(), da_target.max()) else: vmin, vmax = vrange - # Set up masking of border region - mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state) - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, axes = plt.subplots( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.projection()}, + subplot_kw={"projection": datastore.coords_projection}, ) # Plot pred and target - for ax, data in zip(axes, (target, pred)): - ax.coastlines() # Add coastline outlines - data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy() - im = ax.imshow( - data_grid, - origin="lower", - alpha=pixel_alpha, + for ax, da in zip(axes, (da_target, da_prediction)): + plot_on_axis( + ax, + da, + datastore, vmin=vmin, vmax=vmax, - cmap="plasma", ) # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) if title: fig.suptitle(title, size=20) @@ -116,7 +148,9 @@ def plot_prediction( @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): +def plot_spatial_error( + error, datastore: BaseRegularGridDatastore, title=None, vrange=None +): """ Plot errors over spatial map Error and obs_mask has shape (N_grid,) @@ -128,23 +162,25 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): else: vmin, vmax = vrange - # Set up masking of border region - mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state) - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region - fig, ax = plt.subplots( - figsize=(5, 4.8), subplot_kw={"projection": data_config.projection()} + figsize=(5, 4.8), + subplot_kw={"projection": datastore.coords_projection}, ) - ax.coastlines() # Add coastline outlines - error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy() + error_grid = ( + error.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() + .numpy() + ) + extent = datastore.get_xy_extent("state") + # TODO: This needs to be converted to DA and use plot_on_axis im = ax.imshow( error_grid, origin="lower", - alpha=pixel_alpha, + extent=extent, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4b5da0a8..9bd2067a 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,124 +1,944 @@ +# Standard library +import datetime +import warnings +from typing import Union + # Third-party +import numpy as np import pytorch_lightning as pl import torch +import xarray as xr # First-party -from neural_lam import utils +from neural_lam.datastore.base import BaseDatastore +from neural_lam.utils import ( + check_time_overlap, + crop_time_if_needed, + get_time_step, +) class WeatherDataset(torch.utils.data.Dataset): - """ - Dataset class for weather data. + """Dataset class for weather data. + + This class loads and processes weather data from a given datastore. - This class loads and processes weather data from zarr files based on the - provided configuration. It supports splitting the data into train, - validation, and test sets. + Parameters + ---------- + datastore : BaseDatastore + The datastore to load the data from. + datastore_boundary : BaseDatastore + The boundary datastore to load the data from. + split : str, optional + The data split to use ("train", "val" or "test"). Default is "train". + ar_steps : int, optional + The number of autoregressive steps. Default is 3. + num_past_forcing_steps: int, optional + Number of past time steps to include in forcing input. If set to i, + forcing from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as forcing inputs at time t + Default is 1. + num_future_forcing_steps: int, optional + Number of future time steps to include in forcing input. If set to j, + forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before + t, given num_past_forcing_steps) are included as forcing inputs at time + t. Default is 1. + num_past_boundary_steps: int, optional + Number of past time steps to include in boundary input. If set to i, + boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as boundary inputs at time + t Default is 1. + num_future_boundary_steps: int, optional + Number of future time steps to include in boundary input. If set to j, + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times + before t, given num_past_forcing_steps) are included as boundary inputs + at time t. Default is 1. + interior_subsample_step : int, optional + The stride/step size used when sampling interior domain data points. A + value of N means only every Nth point will be sampled in the temporal + dimension. For example, if step_length=3 hours and + interior_subsample_step=2, data will be sampled every 6 hours. Default + is 1 (use every timestep). + boundary_subsample_step : int, optional + The stride/step size used when sampling boundary condition data points. + A value of N means only every Nth point will be sampled in the temporal + dimension. For example, if step_length=3 hours and + boundary_subsample_step=2, boundary conditions will be sampled every 6 + hours. Default is 1 (use every timestep). + standardize : bool, optional + Whether to standardize the data. Default is True. """ + # The current implementation requires at least 2 time steps for the + # initial state (see GraphCast). + INIT_STEPS = 2 # Number of initial state steps needed + def __init__( self, + datastore: BaseDatastore, + datastore_boundary: BaseDatastore, split="train", ar_steps=3, - batch_size=4, - control_only=False, - data_config="neural_lam/data_config.yaml", + num_past_forcing_steps=1, + num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, + interior_subsample_step=1, + boundary_subsample_step=1, + standardize=True, ): super().__init__() - assert split in ( - "train", - "val", - "test", - ), "Unknown dataset split" - self.split = split - self.batch_size = batch_size self.ar_steps = ar_steps - self.control_only = control_only - self.config_loader = utils.ConfigLoader(data_config) + self.datastore = datastore + self.datastore_boundary = datastore_boundary + self.num_past_forcing_steps = num_past_forcing_steps + self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps + self.interior_subsample_step = interior_subsample_step + self.boundary_subsample_step = boundary_subsample_step + # Scale forcing steps based on subsampling + self.effective_past_forcing_steps = ( + num_past_forcing_steps * interior_subsample_step + ) + self.effective_future_forcing_steps = ( + num_future_forcing_steps * interior_subsample_step + ) + self.effective_past_boundary_steps = ( + num_past_boundary_steps * boundary_subsample_step + ) + self.effective_future_boundary_steps = ( + num_future_boundary_steps * boundary_subsample_step + ) - self.state = self.config_loader.process_dataset("state", self.split) - assert self.state is not None, "State dataset not found" - self.forcing = self.config_loader.process_dataset("forcing", self.split) - self.boundary = self.config_loader.process_dataset( - "boundary", self.split + # Validate subsample steps + if ( + not isinstance(interior_subsample_step, int) + or interior_subsample_step < 1 + ): + raise ValueError( + "interior_subsample_step must be a positive integer" + ) + if ( + not isinstance(boundary_subsample_step, int) + or boundary_subsample_step < 1 + ): + raise ValueError( + "boundary_subsample_step must be a positive integer" + ) + + self.da_state = self.datastore.get_dataarray( + category="state", split=self.split ) + if self.da_state is None: + raise ValueError( + "A non-empty state dataarray must be provided. " + "The datastore.get_dataarray() returned None or empty array " + "for category='state'" + ) + self.da_forcing = self.datastore.get_dataarray( + category="forcing", split=self.split + ) + # XXX For now boundary data is always considered mdp-forcing data + if self.datastore_boundary is not None: + self.da_boundary_forcing = self.datastore_boundary.get_dataarray( + category="forcing", split=self.split + ) + else: + self.da_boundary_forcing = None + + # check that with the provided data-arrays and ar_steps that we have a + # non-zero amount of samples + if self.__len__() <= 0: + raise ValueError( + "The provided datastore only provides " + f"{len(self.da_state.time)} total time steps, which is too few " + "to create a single sample for the WeatherDataset " + f"configuration used in the `{split}` split. You could try " + "either reducing the number of autoregressive steps " + "(`ar_steps`) and/or the forcing window size " + "(`num_past_forcing_steps` and `num_future_forcing_steps`)" + ) - self.state_times = self.state.time.values - self.forcing_window = self.config_loader.forcing.window - self.boundary_window = self.config_loader.boundary.window + # Check the dimensions and their ordering + parts = dict(state=self.da_state) + if self.da_forcing is not None: + parts["forcing"] = self.da_forcing - if self.forcing is not None: - self.forcing_windowed = ( - self.forcing.sel( - time=self.state.time, - method="nearest", + for part, da in parts.items(): + expected_dim_order = self.datastore.expected_dim_order( + category=part + ) + if da.dims != expected_dim_order: + raise ValueError( + f"The dimension order of the `{part}` data ({da.dims}) " + f"does not match the expected dimension order " + f"({expected_dim_order}). Maybe you forgot to transpose " + "the data in `BaseDatastore.get_dataarray`?" ) - .pad( - time=(self.forcing_window // 2, self.forcing_window // 2), - mode="edge", + + # handling ensemble data + if self.datastore.is_ensemble: + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 + self.da_state = self.da_state.isel(ensemble_member=i_ensemble) + + # Check time step consistency in state data and determine time steps + # for state, forcing and boundary forcing data + # STATE + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + self.forecast_step_state = get_time_step( + self.da_state.elapsed_forecast_duration + ) + else: + state_times = self.da_state.time + self.time_step_state = get_time_step(state_times) + # FORCING + if self.da_forcing is not None: + if self.datastore.is_forecast: + forcing_times = self.da_forcing.analysis_time + self.forecast_step_forcing = get_time_step( + self.da_forcing.elapsed_forecast_duration ) - .rolling(time=self.forcing_window, center=True) - .construct("window") + else: + forcing_times = self.da_forcing.time + self.time_step_forcing = get_time_step(forcing_times.values) + # inform user about the original and the subsampled time step + if self.interior_subsample_step != 1: + print( + f"Subsampling interior data with step size " + f"{self.interior_subsample_step} from original time step " + f"{self.time_step_state}" ) + else: + print(f"Using original time step {self.time_step_state} for data") - if self.boundary is not None: - self.boundary_windowed = ( - self.boundary.sel( - time=self.state.time, - method="nearest", + # BOUNDARY FORCING + if self.da_boundary_forcing is not None: + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary_forcing.analysis_time + self.forecast_step_boundary = get_time_step( + self.da_boundary_forcing.elapsed_forecast_duration ) - .pad( - time=(self.boundary_window // 2, self.boundary_window // 2), - mode="edge", + else: + boundary_times = self.da_boundary_forcing.time + self.time_step_boundary = get_time_step(boundary_times.values) + + if self.boundary_subsample_step != 1: + print( + f"Subsampling boundary data with step size " + f"{self.boundary_subsample_step} from original time step " + f"{self.time_step_boundary}" + ) + else: + print( + f"Using original time step {self.time_step_boundary} for " + "boundary data" ) - .rolling(time=self.boundary_window, center=True) - .construct("window") + + # Forcing data is part of the same datastore as state data. During + # creation, the time dimension of the forcing data is matched to the + # state data. + # Boundary data is part of a separate datastore The boundary data is + # allowed to have a different time_step Checks that the boundary data + # covers the required time range is required. + # Crop interior data if boundary coverage is insufficient + if self.da_boundary_forcing is not None: + self.da_state = crop_time_if_needed( + self.da_state, + self.da_boundary_forcing, + da1_is_forecast=self.datastore.is_forecast, + da2_is_forecast=self.datastore_boundary.is_forecast, + num_past_steps=self.num_past_boundary_steps, + num_future_steps=self.num_future_boundary_steps, + ) + + # Now do final overlap check and possibly raise errors if still invalid + if self.da_boundary_forcing is not None: + check_time_overlap( + self.da_state, + self.da_boundary_forcing, + da1_is_forecast=self.datastore.is_forecast, + da2_is_forecast=self.datastore_boundary.is_forecast, + num_past_steps=self.num_past_boundary_steps, + num_future_steps=self.num_future_boundary_steps, ) + # Set up for standardization + # TODO: This will become part of ar_model.py soon! + self.standardize = standardize + if standardize: + self.ds_state_stats = self.datastore.get_standardization_dataarray( + category="state" + ) + + self.da_state_mean = self.ds_state_stats.state_mean + self.da_state_std = self.ds_state_stats.state_std + + if self.da_forcing is not None: + self.ds_forcing_stats = ( + self.datastore.get_standardization_dataarray( + category="forcing" + ) + ) + self.da_forcing_mean = self.ds_forcing_stats.forcing_mean + self.da_forcing_std = self.ds_forcing_stats.forcing_std + + # XXX: Again, the boundary data is considered forcing data for now + if self.da_boundary_forcing is not None: + self.ds_boundary_stats = ( + self.datastore_boundary.get_standardization_dataarray( + category="forcing" + ) + ) + self.da_boundary_mean = self.ds_boundary_stats.forcing_mean + self.da_boundary_std = self.ds_boundary_stats.forcing_std + def __len__(self): - # Skip first and last time step - return len(self.state.time) - self.ar_steps + if self.datastore.is_ensemble: + warnings.warn( + "only using first ensemble member, so dataset size is " + " effectively reduced by the number of ensemble members " + f"({self.datastore.num_ensemble_members})", + UserWarning, + ) - def __getitem__(self, idx): - sample = torch.tensor( - self.state.isel(time=slice(idx, idx + self.ar_steps)).values, - dtype=torch.float32, - ) + if self.datastore.is_forecast: + # for now we simply create a single sample for each analysis time + # and then take the first (2 + ar_steps) forecast times. In + # addition we only use the first ensemble member (if ensemble data + # has been provided). + # This means that for each analysis time we get a single sample + # check that there are enough forecast steps available to create + # samples given the number of autoregressive steps requested + required_steps = self.INIT_STEPS + self.ar_steps + required_span = (required_steps - 1) * self.interior_subsample_step + + # Calculate available forecast steps + n_forecast_steps = len(self.da_state.elapsed_forecast_duration) + + if n_forecast_steps < required_span: + raise ValueError( + f"Not enough forecast steps ({n_forecast_steps}) for " + f"required span of {required_span} steps with " + f"subsample_step={self.interior_subsample_step}" + ) + + return self.da_state.analysis_time.size + else: + # Calculate the number of samples in the dataset as: + # total_samples = total_timesteps - required_time_span - + # required_past_steps - effective_future_forcing_steps + # Where: + # - total_timesteps: total number of timesteps in the state data + # - required_time_span: number of continuous timesteps needed for + # initial state + autoregressive steps, accounting for subsampling + # - required_past_steps: additional past timesteps needed for + # forcing data beyond initial state + # - effective_future_forcing_steps: number of future timesteps + # needed for forcing data with subsampling + required_continuous_steps = self.INIT_STEPS + self.ar_steps + required_time_span = ( + required_continuous_steps * self.interior_subsample_step + ) + required_past_steps = max( + 0, + self.effective_past_forcing_steps + - self.INIT_STEPS * self.interior_subsample_step, + ) - forcing = ( - self.forcing_windowed.isel(time=slice(idx + 2, idx + self.ar_steps)) - .stack(variable_window=("variable", "window")) - .values - if self.forcing is not None - else torch.tensor([]) + return ( + len(self.da_state.time) + - required_time_span + - required_past_steps + - self.effective_future_forcing_steps + ) + + def _slice_time( + self, + da_state, + idx, + n_steps: int, + da_forcing=None, + num_past_steps=None, + num_future_steps=None, + is_boundary=False, + ): + """ + Produce time slices of the given dataarrays `da_state` (state) and + `da_forcing`. For the state data, slicing is done based on `idx`. For + the forcing/boundary data, nearest neighbor matching is performed based + on the state times (assuming constant timestep size). Additionally, the + time deltas between the matched forcing/boundary times and state times + (in multiples of state time steps) is added to the forcing dataarray. + This will be used as an additional input feature in the model (as + temporal embedding). + + Parameters + ---------- + da_state : xr.DataArray + The state dataarray to slice. + idx : int + The index of the time step to start the sample from in the state + data. + n_steps : int + The number of time steps to include in the sample. + da_forcing : xr.DataArray + The forcing/boundary dataarray to slice. + num_past_steps : int, optional + The number of past time steps to include in the forcing/boundary + data. Default is `None`. + num_future_steps : int, optional + The number of future time steps to include in the forcing/boundary + data. Default is `None`. + is_boundary : bool, optional + Whether the data is boundary data. Default is `False`. + + Returns + ------- + da_state_sliced : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', + 'state_feature'). + da_forcing_matched : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', + 'forcing/boundary_feature_windowed'). + If no forcing/boundary data is provided, this will be `None`. + """ + init_steps = self.INIT_STEPS + subsample_step = ( + self.boundary_subsample_step + if is_boundary + else self.interior_subsample_step ) + # slice the dataarray to include the required number of time steps + if self.datastore.is_forecast: + # this implies that the data will have both `analysis_time` and + # `elapsed_forecast_duration` dimensions for forecasts. We for now + # simply select a analysis time and the first `n_steps` forecast + # times (given no offset). Note that this means that we get one + # sample per forecast, always starting at forecast time 2. - boundary = ( - self.boundary_windowed.isel( - time=slice(idx + 2, idx + self.ar_steps) + # Calculate base offset and indices with subsampling + offset = ( + max(0, num_past_steps - init_steps) if num_past_steps else 0 ) - .stack(variable_window=("variable", "window")) - .values - if self.boundary is not None - else torch.tensor([]) + + # Calculate initial and target indices + init_indices = [ + offset + i * subsample_step for i in range(init_steps) + ] + target_indices = [ + offset + (init_steps + i) * subsample_step + for i in range(n_steps) + ] + all_indices = init_indices + target_indices + + da_state_sliced = da_state.isel( + analysis_time=idx, + elapsed_forecast_duration=all_indices, + ) + da_state_sliced["time"] = ( + da_state_sliced.analysis_time + + da_state_sliced.elapsed_forecast_duration + ) + da_state_sliced = da_state_sliced.swap_dims( + {"elapsed_forecast_duration": "time"} + ) + + else: + # Analysis data slicing, already correctly modified + start_idx = idx + ( + max(0, num_past_steps - init_steps) if num_past_steps else 0 + ) + all_indices = [ + start_idx + i * subsample_step + for i in range(init_steps + n_steps) + ] + da_state_sliced = da_state.isel(time=all_indices) + + if da_forcing is None: + return da_state_sliced, None + + # Get the state times and its temporal resolution for matching with + # forcing data. + state_times = da_state_sliced["time"] + da_list = [] + # Here we cannot check 'self.datastore.is_forecast' directly because we + # might be dealing with a datastore_boundary + if "analysis_time" in da_forcing.dims: + # For forecast data with analysis_time and elapsed_forecast_duration + # Select the closest analysis_time in the past in the + # forcing/boundary data + offset = max(0, num_past_steps - init_steps) + state_time = state_times[init_steps].values + forcing_analysis_time_idx = da_forcing.analysis_time.get_index( + "analysis_time" + ).get_indexer([state_time], method="pad")[0] + + # Adjust window indices for subsampled steps + for step_idx in range(init_steps, len(state_times)): + window_start = ( + offset + + step_idx * subsample_step + - num_past_steps * subsample_step + ) + window_end = ( + offset + + step_idx * subsample_step + + (num_future_steps + 1) * subsample_step + ) + + current_time = ( + forcing_analysis_time_idx + + da_forcing.elapsed_forecast_duration[ + step_idx * subsample_step + ] + ) + + da_sliced = da_forcing.isel( + analysis_time=forcing_analysis_time_idx, + elapsed_forecast_duration=slice( + window_start, window_end, subsample_step + ), + ) + da_sliced = da_sliced.rename( + {"elapsed_forecast_duration": "window"} + ) + + # Assign the 'window' coordinate to be relative positions + da_sliced = da_sliced.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) + ) + # Calculate window time deltas for forecast data + window_time_deltas = ( + da_forcing.elapsed_forecast_duration[ + window_start:window_end:subsample_step + ].values + - da_forcing.elapsed_forecast_duration[ + step_idx * subsample_step + ].values + ) + # Assign window time delta coordinate + da_sliced["window_time_deltas"] = ("window", window_time_deltas) + + da_sliced = da_sliced.expand_dims( + dim={"time": [current_time.values]} + ) + + da_list.append(da_sliced) + + else: + for idx_time in range(init_steps, len(state_times)): + state_time = state_times[idx_time].values + + # Select the closest time in the past from forcing data using + # sel with method="pad" + forcing_time_idx = da_forcing.time.get_index( + "time" + ).get_indexer([state_time], method="pad")[0] + + window_start = ( + forcing_time_idx - num_past_steps * subsample_step + ) + window_end = ( + forcing_time_idx + (num_future_steps + 1) * subsample_step + ) + + da_window = da_forcing.isel( + time=slice(window_start, window_end, subsample_step) + ) + + # Rename the time dimension to window for consistency + da_window = da_window.rename({"time": "window"}) + + # Assign the 'window' coordinate to be relative positions + da_window = da_window.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) + ) + + # Calculate window time deltas for analysis data + window_time_deltas = ( + da_forcing.time[ + window_start:window_end:subsample_step + ].values + - da_forcing.time[forcing_time_idx].values + ) + da_window["window_time_deltas"] = ("window", window_time_deltas) + + da_window = da_window.expand_dims(dim={"time": [state_time]}) + + da_list.append(da_window) + + da_forcing_matched = xr.concat(da_list, dim="time") + + return da_state_sliced, da_forcing_matched + + def _process_windowed_data( + self, da_windowed, da_state, da_target_times, add_time_deltas=True + ): + """Helper function to process windowed data. This function stacks the + 'forcing_feature' and 'window' dimensions and adds the time step + deltas to the existing features. + + Parameters + ---------- + da_windowed : xr.DataArray + The windowed data to process. Can be `None` if no data is provided. + da_state : xr.DataArray + The state dataarray. + da_target_times : xr.DataArray + The target times. + add_time_deltas : bool + If time deltas to each window position should be concatenated + as features + + Returns + ------- + da_windowed : xr.DataArray + The processed windowed data. If `da_windowed` is `None`, an empty + DataArray with the correct dimensions and coordinates is returned. + + """ + stacked_dim = "forcing_feature_windowed" + if da_windowed is not None: + window_size = da_windowed.window.size + # Stack the 'feature' and 'window' dimensions and add the + # time deltas to the existing features + da_windowed = da_windowed.stack( + {stacked_dim: ("forcing_feature", "window")} + ) + if add_time_deltas: + # Add the time deltas a new feature to the windowed + # data, as a multiple of the state time step + time_deltas = ( + da_windowed["window_time_deltas"].isel( + forcing_feature_windowed=slice(0, window_size) + ) + / self.time_step_state + ) + # All data variables share the same time deltas + da_windowed = xr.concat( + [da_windowed, time_deltas], + dim="forcing_feature_windowed", + ) + else: + # Create empty DataArray with the correct dimensions and coordinates + da_windowed = xr.DataArray( + data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), + dims=("time", "grid_index", f"{stacked_dim}"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + f"{stacked_dim}": [], + }, + ) + return da_windowed + + def _build_item_dataarrays(self, idx): + """ + Create the dataarrays for the initial states, target states, forcing + and boundary data for the sample at index `idx`. + + Parameters + ---------- + idx : int + The index of the sample to create the dataarrays for. + + Returns + ------- + da_init_states : xr.DataArray + The dataarray for the initial states. + da_target_states : xr.DataArray + The dataarray for the target states. + da_forcing_windowed : xr.DataArray + The dataarray for the forcing data, windowed for the sample. + da_boundary_windowed : xr.DataArray + The dataarray for the boundary data, windowed for the sample. + Boundary data is always considered forcing data. + da_target_times : xr.DataArray + The dataarray for the target times. + """ + da_state = self.da_state + if self.da_forcing is not None: + if "ensemble_member" in self.da_forcing.dims: + raise NotImplementedError( + "Ensemble member not yet supported for forcing data" + ) + da_forcing = self.da_forcing + else: + da_forcing = None + + if self.da_boundary_forcing is not None: + da_boundary = self.da_boundary_forcing + else: + da_boundary = None + + # This function will return a slice of the state data and the forcing + # and boundary data (if provided) for one sample (idx). + # If da_forcing is None, the function will return None for + # da_forcing_windowed. + if da_boundary is not None: + _, da_boundary_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_boundary, + num_future_steps=self.num_future_boundary_steps, + num_past_steps=self.num_past_boundary_steps, + is_boundary=True, + ) + else: + da_boundary_windowed = None + # XXX: Currently, the order of the `slice_time` calls is important + # as `da_state` is modified in the second call. This should be + # refactored to be more robust. + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, + num_future_steps=self.num_future_forcing_steps, + num_past_steps=self.num_past_forcing_steps, + ) + + # load the data into memory + da_state.load() + if da_forcing is not None: + da_forcing_windowed.load() + if da_boundary is not None: + da_boundary_windowed.load() + + da_init_states = da_state.isel(time=slice(0, 2)) + da_target_states = da_state.isel(time=slice(2, None)) + da_target_times = da_target_states.time + + if self.standardize: + da_init_states = ( + da_init_states - self.da_state_mean + ) / self.da_state_std + da_target_states = ( + da_target_states - self.da_state_mean + ) / self.da_state_std + + if da_forcing is not None: + # XXX: Here we implicitly assume that the last dimension of the + # forcing data is the forcing feature dimension. To standardize + # on `.device` we need a different implementation. (e.g. a + # tensor with repeated means and stds for each "windowed" time.) + da_forcing_windowed = ( + da_forcing_windowed - self.da_forcing_mean + ) / self.da_forcing_std + + if da_boundary is not None: + da_boundary_windowed = ( + da_boundary_windowed - self.da_boundary_mean + ) / self.da_boundary_std + + # This function handles the stacking of the forcing and boundary data + # and adds the time deltas. It can handle `None` inputs for the forcing + # and boundary data (and simlpy return an empty DataArray in that case). + # We don't need time delta features for interior forcing, as these + # deltas are always the same. + da_forcing_windowed = self._process_windowed_data( + da_forcing_windowed, + da_state, + da_target_times, + add_time_deltas=False, + ) + da_boundary_windowed = self._process_windowed_data( + da_boundary_windowed, + da_state, + da_target_times, + add_time_deltas=True, ) - init_states = sample[:2] - target_states = sample[2:] + return ( + da_init_states, + da_target_states, + da_forcing_windowed, + da_boundary_windowed, + da_target_times, + ) + + def __getitem__(self, idx): + """ + Return a single training sample, which consists of the initial states, + target states, forcing and batch times. + + The implementation currently uses xarray.DataArray objects for the + standardization (scaling to mean 0.0 and standard deviation of 1.0) so + that we can make us of xarray's broadcasting capabilities. This makes + it possible to standardization with both global means, but also for + example where a grid-point mean has been computed. This code will have + to be replace if standardization is to be done on the GPU to handle + different shapes of the standardization. + + Parameters + ---------- + idx : int + The index of the sample to return, this will refer to the time of + the initial state. + + Returns + ------- + init_states : TrainingSample + A training sample object containing the initial states, target + states, forcing and batch times. The batch times are the times of + the target steps. + + """ + ( + da_init_states, + da_target_states, + da_forcing_windowed, + da_boundary_windowed, + da_target_times, + ) = self._build_item_dataarrays(idx=idx) + + tensor_dtype = torch.float32 - batch_times = ( - self.state.isel(time=slice(idx + 2, idx + self.ar_steps)) - .time.values.astype(str) - .tolist() + init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype) + target_states = torch.tensor( + da_target_states.values, dtype=tensor_dtype ) + target_times = torch.tensor( + da_target_times.astype("datetime64[ns]").astype("int64").values, + dtype=torch.int64, + ) + + forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype) + boundary = torch.tensor(da_boundary_windowed.values, dtype=tensor_dtype) + # init_states: (2, N_grid, d_features) - # target_states: (ar_steps-2, N_grid, d_features) - # forcing: (ar_steps-2, N_grid, d_windowed_forcing) - # boundary: (ar_steps-2, N_grid, d_windowed_boundary) - # batch_times: (ar_steps-2,) - return init_states, target_states, forcing, boundary, batch_times + # target_states: (ar_steps, N_grid, d_features) + # forcing: (ar_steps, N_grid, d_windowed_forcing) + # boundary: (ar_steps, N_grid, d_windowed_boundary) + # target_times: (ar_steps,) + + # Assert that the boundary data is an empty tensor if the corresponding + # datastore_boundary is `None` + if self.datastore_boundary is None: + assert boundary.numel() == 0 + + return init_states, target_states, forcing, boundary, target_times + + def __iter__(self): + """ + Convenience method to iterate over the dataset. + + This isn't used by pytorch DataLoader which itself implements an + iterator that uses Dataset.__getitem__ and Dataset.__len__. + + """ + for i in range(len(self)): + yield self[i] + + def create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: Union[datetime.datetime, list[datetime.datetime]], + category: str, + ): + """ + Construct a xarray.DataArray from a `pytorch.Tensor` with coordinates + for `grid_index`, `time` and `{category}_feature` matching the shape + and number of times provided and add the x/y coordinates from the + datastore. + + The number if times provided is expected to match the shape of the + tensor. For a 2D tensor, the dimensions are assumed to be (grid_index, + {category}_feature) and only a single time should be provided. For a 3D + tensor, the dimensions are assumed to be (time, grid_index, + {category}_feature) and a list of times should be provided. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to construct the DataArray from, this assumed to have + the same dimension ordering as returned by the __getitem__ method + (i.e. time, grid_index, {category}_feature). The tensor will be + copied to the CPU before constructing the DataArray. + time : datetime.datetime or list[datetime.datetime] + The time or times of the tensor. + category : str + The category of the tensor, either "state", "forcing" or "static". + + Returns + ------- + da : xr.DataArray + The constructed DataArray. + """ + + def _is_listlike(obj): + # match list, tuple, numpy array + return hasattr(obj, "__iter__") and not isinstance(obj, str) + + add_time_as_dim = False + if len(tensor.shape) == 2: + dims = ["grid_index", f"{category}_feature"] + if _is_listlike(time): + raise ValueError( + "Expected a single time for a 2D tensor with assumed " + "dimensions (grid_index, {category}_feature), but got " + f"{len(time)} times" + ) + elif len(tensor.shape) == 3: + add_time_as_dim = True + dims = ["time", "grid_index", f"{category}_feature"] + if not _is_listlike(time): + raise ValueError( + "Expected a list of times for a 3D tensor with assumed " + "dimensions (time, grid_index, {category}_feature), but " + "got a single time" + ) + else: + raise ValueError( + "Expected tensor to have 2 or 3 dimensions, but got " + f"{len(tensor.shape)}" + ) + + da_datastore_state = getattr(self, f"da_{category}") + da_grid_index = da_datastore_state.grid_index + da_state_feature = da_datastore_state.state_feature + + coords = { + f"{category}_feature": da_state_feature, + "grid_index": da_grid_index, + } + if add_time_as_dim: + coords["time"] = time + + tensor = tensor.detach().cpu().numpy() + da = xr.DataArray( + tensor, + dims=dims, + coords=coords, + ) + + for grid_coord in ["x", "y"]: + if ( + grid_coord in da_datastore_state.coords + and grid_coord not in da.coords + ): + da.coords[grid_coord] = da_datastore_state[grid_coord] + + if not add_time_as_dim: + da.coords["time"] = time + + return da class WeatherDataModule(pl.LightningDataModule): @@ -126,38 +946,88 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, + datastore: BaseDatastore, + datastore_boundary: BaseDatastore, ar_steps_train=3, ar_steps_eval=25, + standardize=True, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, + interior_subsample_step=1, + boundary_subsample_step=1, batch_size=4, num_workers=16, ): super().__init__() + self._datastore = datastore + self._datastore_boundary = datastore_boundary + self.num_past_forcing_steps = num_past_forcing_steps + self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps + self.interior_subsample_step = interior_subsample_step + self.boundary_subsample_step = boundary_subsample_step self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval + self.standardize = standardize self.batch_size = batch_size self.num_workers = num_workers self.train_dataset = None self.val_dataset = None self.test_dataset = None + if num_workers > 0: + # BUG: There also seem to be issues with "spawn" and `gloo`, to be + # investigated. Defaults to spawn for now, as the default on linux + # "fork" hangs when using dask (which the npyfilesmeps datastore + # uses) + self.multiprocessing_context = "spawn" + else: + self.multiprocessing_context = None def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( + datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="train", ar_steps=self.ar_steps_train, - batch_size=self.batch_size, + standardize=self.standardize, + num_past_forcing_steps=self.num_past_forcing_steps, + num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, + interior_subsample_step=self.interior_subsample_step, + boundary_subsample_step=self.boundary_subsample_step, ) self.val_dataset = WeatherDataset( + datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="val", ar_steps=self.ar_steps_eval, - batch_size=self.batch_size, + standardize=self.standardize, + num_past_forcing_steps=self.num_past_forcing_steps, + num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, + interior_subsample_step=self.interior_subsample_step, + boundary_subsample_step=self.boundary_subsample_step, ) if stage == "test" or stage is None: self.test_dataset = WeatherDataset( + datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="test", ar_steps=self.ar_steps_eval, - batch_size=self.batch_size, + standardize=self.standardize, + num_past_forcing_steps=self.num_past_forcing_steps, + num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, + interior_subsample_step=self.interior_subsample_step, + boundary_subsample_step=self.boundary_subsample_step, ) def train_dataloader(self): @@ -166,7 +1036,9 @@ def train_dataloader(self): self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, - shuffle=False, + shuffle=True, + multiprocessing_context=self.multiprocessing_context, + persistent_workers=True, ) def val_dataloader(self): @@ -176,6 +1048,8 @@ def val_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, + multiprocessing_context=self.multiprocessing_context, + persistent_workers=True, ) def test_dataloader(self): @@ -185,4 +1059,6 @@ def test_dataloader(self): batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, + multiprocessing_context=self.multiprocessing_context, + persistent_workers=True, ) diff --git a/plot_graph.py b/plot_graph.py deleted file mode 100644 index 06167143..00000000 --- a/plot_graph.py +++ /dev/null @@ -1,213 +0,0 @@ -# Standard library -from argparse import ArgumentParser - -# Third-party -import numpy as np -import plotly.graph_objects as go -import torch_geometric as pyg - -# First-party -from neural_lam import utils - -MESH_HEIGHT = 0.1 -MESH_LEVEL_DIST = 0.2 -GRID_HEIGHT = 0 - - -def main(): - """ - Plot graph structure in 3D using plotly - """ - parser = ArgumentParser(description="Plot graph") - parser.add_argument( - "--graph", - type=str, - default="multiscale", - help="Graph to plot (default: multiscale)", - ) - parser.add_argument( - "--save", - type=str, - help="Name of .html file to save interactive plot to (default: None)", - ) - parser.add_argument( - "--show_axis", - type=int, - default=0, - help="If the axis should be displayed (default: 0 (No))", - ) - parser.add_argument( - "--data_config", - type=str, - default="neural_lam/data_config.yaml", - help="Path to data config file (default: neural_lam/data_config.yaml)", - ) - - args = parser.parse_args() - - # Load graph data - hierarchical, graph_ldict = utils.load_graph(args.graph) - ( - g2m_edge_index, - m2g_edge_index, - m2m_edge_index, - ) = ( - graph_ldict["g2m_edge_index"], - graph_ldict["m2g_edge_index"], - graph_ldict["m2m_edge_index"], - ) - mesh_up_edge_index, mesh_down_edge_index = ( - graph_ldict["mesh_up_edge_index"], - graph_ldict["mesh_down_edge_index"], - ) - mesh_static_features = graph_ldict["mesh_static_features"] - - config_loader = utils.ConfigLoader(args.data_config) - xy = config_loader.get_nwp_xy("static") - grid_xy = xy.transpose(1, 2, 0).reshape(-1, 2) # (N_grid, 2) - pos_max = np.max(np.abs(grid_xy)) - grid_pos = grid_xy / pos_max # Divide by maximum coordinate - - # Add in z-dimension - z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) - grid_pos = np.concatenate( - (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 - ) - - # List of edges to plot, (edge_index, color, line_width, label) - edge_plot_list = [ - (m2g_edge_index.numpy(), "black", 0.4, "M2G"), - (g2m_edge_index.numpy(), "black", 0.4, "G2M"), - ] - - # Mesh positioning and edges to plot differ if we have a hierarchical graph - if hierarchical: - mesh_level_pos = [ - np.concatenate( - ( - level_static_features.numpy(), - MESH_HEIGHT - + MESH_LEVEL_DIST - * height_level - * np.ones((level_static_features.shape[0], 1)), - ), - axis=1, - ) - for height_level, level_static_features in enumerate( - mesh_static_features, start=1 - ) - ] - mesh_pos = np.concatenate(mesh_level_pos, axis=0) - - # Add inter-level mesh edges - edge_plot_list += [ - (level_ei.numpy(), "blue", 1, f"M2M Level {level}") - for level, level_ei in enumerate(m2m_edge_index) - ] - - # Add intra-level mesh edges - up_edges_ei = np.concatenate( - [level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1 - ) - down_edges_ei = np.concatenate( - [level_down_ei.numpy() for level_down_ei in mesh_down_edge_index], - axis=1, - ) - edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up")) - edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down")) - - mesh_node_size = 2.5 - else: - mesh_pos = mesh_static_features.numpy() - - mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy() - z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees - mesh_node_size = mesh_degrees / 2 - - mesh_pos = np.concatenate( - (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 - ) - - edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M")) - - # All node positions in one array - node_pos = np.concatenate((mesh_pos, grid_pos), axis=0) - - # Add edges - data_objs = [] - for ( - ei, - col, - width, - label, - ) in edge_plot_list: - edge_start = node_pos[ei[0]] # (M, 2) - edge_end = node_pos[ei[1]] # (M, 2) - n_edges = edge_start.shape[0] - - x_edges = np.stack( - (edge_start[:, 0], edge_end[:, 0], np.full(n_edges, None)), axis=1 - ).flatten() - y_edges = np.stack( - (edge_start[:, 1], edge_end[:, 1], np.full(n_edges, None)), axis=1 - ).flatten() - z_edges = np.stack( - (edge_start[:, 2], edge_end[:, 2], np.full(n_edges, None)), axis=1 - ).flatten() - - scatter_obj = go.Scatter3d( - x=x_edges, - y=y_edges, - z=z_edges, - mode="lines", - line={"color": col, "width": width}, - name=label, - ) - data_objs.append(scatter_obj) - - # Add node objects - - data_objs.append( - go.Scatter3d( - x=grid_pos[:, 0], - y=grid_pos[:, 1], - z=grid_pos[:, 2], - mode="markers", - marker={"color": "black", "size": 1}, - name="Grid nodes", - ) - ) - data_objs.append( - go.Scatter3d( - x=mesh_pos[:, 0], - y=mesh_pos[:, 1], - z=mesh_pos[:, 2], - mode="markers", - marker={"color": "blue", "size": mesh_node_size}, - name="Mesh nodes", - ) - ) - - fig = go.Figure(data=data_objs) - - fig.update_layout(scene_aspectmode="data") - fig.update_traces(connectgaps=False) - - if not args.show_axis: - # Hide axis - fig.update_layout( - scene={ - "xaxis": {"visible": False}, - "yaxis": {"visible": False}, - "zaxis": {"visible": False}, - } - ) - - if args.save: - fig.write_html(args.save, include_plotlyjs="cdn") - else: - fig.show() - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 619f444f..2b6cd5af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,40 @@ [project] -name = "neural_lam" -version = "0.1.0" +name = "neural-lam" +version = "0.2.0" +description = "LAM-based data-driven forecasting" +authors = [ + { name = "Joel Oskarsson", email = "joel.oskarsson@liu.se" }, + { name = "Simon Adamov", email = "Simon.Adamov@meteoswiss.ch" }, + { name = "Leif Denby", email = "lcd@dmi.dk" }, +] + +# PEP 621 project metadata +# See https://www.python.org/dev/peps/pep-0621/ +dependencies = [ + "numpy>=1.24.2", + "wandb>=0.13.10", + "scipy>=1.10.0", + "pytorch-lightning>=2.0.3", + "shapely>=2.0.1", + "Cartopy>=0.22.0", + "pyproj>=3.4.1", + "tueplots>=0.0.8", + "matplotlib>=3.7.0", + "plotly>=5.15.0", + "torch>=2.3.0", + "torch-geometric==2.3.1", + "parse>=1.20.2", + "dataclass-wizard<0.31.0", + "mllam-data-prep>=0.5.0", + "weather-model-graphs @ git+https://github.com/joeloskarsson/weather-model-graphs.git@decoding_mask" +] +requires-python = ">=3.10" + +[project.optional-dependencies] +dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2", "gcsfs>=2021.10.0"] [tool.setuptools] -packages = ["neural_lam"] +py-modules = ["neural_lam"] [tool.black] line-length = 80 @@ -28,6 +59,7 @@ known_first_party = [ # Add first-party modules that may be misclassified by isort "neural_lam", ] +line_length = 80 [tool.flake8] max-line-length = 80 @@ -68,3 +100,9 @@ max-statements = 100 # Allow for some more involved functions allow-any-import-level = "neural_lam" [tool.pylint.SIMILARITIES] min-similarity-lines = 10 + + +[tool.pdm] +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7223ff0a..00000000 --- a/requirements.txt +++ /dev/null @@ -1,27 +0,0 @@ -# for all -numpy>=1.24.2 -wandb>=0.13.10 -matplotlib>=3.7.0 -scipy>=1.10.0 -pytorch-lightning>=2.0.3 -shapely>=2.0.1 -networkx>=3.0 -Cartopy>=0.22.0 -pyproj>=3.4.1 -tueplots>=0.0.8 -plotly>=5.15.0 -xarray>=0.20.1 -zarr>=2.10.0 -dask>=2022.0.0 -geopandas>=1.0.0 -anemoi-datasets>=0.4.0 -rasterio>=1.2.0 -affine>=2.3.0 -gcsfs>=2022.0.0 -# for dev -codespell>=2.0.0 -black>=21.9b0 -isort>=5.9.3 -flake8>=4.0.1 -pylint>=3.0.3 -pre-commit>=2.15.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..ea06862e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,207 @@ +# Standard library +import os +from pathlib import Path + +# Third-party +import numpy as np +import pooch +import torch +import yaml + +# First-party +from neural_lam.datastore import DATASTORES, init_datastore +from neural_lam.datastore.npyfilesmeps import ( + compute_standardization_stats as compute_standardization_stats_meps, +) +from neural_lam.utils import get_stacked_xy + +# Local +from .dummy_datastore import DummyDatastore + +# Disable weights and biases to avoid unnecessary logging +# and to avoid having to deal with authentication +os.environ["WANDB_MODE"] = "disabled" + +DATASTORE_EXAMPLES_ROOT_PATH = Path("tests/datastore_examples") + +# Initializing variables for the s3 client +S3_BUCKET_NAME = "mllam-testdata" +S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int" +S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.2.0.zip" +S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH]) +TEST_DATA_KNOWN_HASH = ( + "7ff2e07e04cfcd77631115f800c9d49188bb2a7c2a2777da3cea219f926d0c86" +) + + +def download_meps_example_reduced_dataset(): + # Download and unzip test data into data/meps_example_reduced + root_path = DATASTORE_EXAMPLES_ROOT_PATH / "npyfilesmeps" + dataset_path = root_path / "meps_example_reduced" + + pooch.retrieve( + url=S3_FULL_PATH, + known_hash=TEST_DATA_KNOWN_HASH, + processor=pooch.Unzip(extract_dir=""), + path=root_path, + fname="meps_example_reduced.zip", + ) + + config_path = dataset_path / "meps_example_reduced.datastore.yaml" + + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + if "class" in config["projection"]: + # XXX: should update the dataset stored on S3 with the change below + # + # rename the `projection.class` key to `projection.class_name` in the + # config this is because the `class` key is reserved for the class + # attribute of the object and so we can't use it to define a python + # dataclass + config["projection"]["class_name"] = config["projection"].pop("class") + + with open(config_path, "w") as f: + yaml.dump(config, f) + + # create parameters, only run if the files we expect are not present + expected_parameter_files = [ + "parameter_mean.pt", + "parameter_std.pt", + "diff_mean.pt", + "diff_std.pt", + ] + expected_parameter_filepaths = [ + dataset_path / "static" / fn for fn in expected_parameter_files + ] + if any(not p.exists() for p in expected_parameter_filepaths): + compute_standardization_stats_meps.main( + datastore_config_path=config_path, + batch_size=8, + step_length=3, + n_workers=0, + distributed=False, + ) + + return config_path + + +DATASTORES_EXAMPLES = dict( + mdp=( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "danra_100m_winds" + / "danra.datastore.yaml" + ), + npyfilesmeps=download_meps_example_reduced_dataset(), + dummydata=None, +) + +DATASTORES_BOUNDARY_EXAMPLES = { + "mdp": ( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "era5_1000hPa_danra_100m_winds" + / "era5.datastore.yaml" + ), +} + +DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore + + +def init_datastore_example(datastore_kind): + datastore = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_EXAMPLES[datastore_kind], + ) + return datastore + + +def init_datastore_boundary_example(datastore_kind): + datastore_boundary = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_BOUNDARY_EXAMPLES[datastore_kind], + ) + + return datastore_boundary + + +def get_test_mesh_dist(datastore, datastore_boundary): + """Compute a good mesh_node_distance for testing graph creation with + given datastores + """ + xy = get_stacked_xy(datastore, datastore_boundary) # (num_grid, 2) + # Compute minimum coordinate extent + min_extent = min(np.ptp(xy, axis=0)) + + # Want at least 10 mesh nodes in each direction + return min_extent / 10.0 + + +def check_saved_graph(graph_dir_path, hierarchical, num_levels=1): + """Perform all checking for a saved graph""" + required_graph_files = [ + "m2m_edge_index.pt", + "g2m_edge_index.pt", + "m2g_edge_index.pt", + "m2m_features.pt", + "g2m_features.pt", + "m2g_features.pt", + "m2m_node_features.pt", + ] + + if hierarchical: + required_graph_files.extend( + [ + "mesh_up_edge_index.pt", + "mesh_down_edge_index.pt", + "mesh_up_features.pt", + "mesh_down_features.pt", + ] + ) + + # TODO: check that the number of edges is consistent over the files, for + # now we just check the number of features + d_features = 3 + d_mesh_static = 2 + + assert graph_dir_path.exists() + + # check that all the required files are present + for file_name in required_graph_files: + assert (graph_dir_path / file_name).exists() + + # try to load each and ensure they have the right shape + for file_name in required_graph_files: + file_id = Path(file_name).stem # remove the extension + result = torch.load(graph_dir_path / file_name, weights_only=True) + + if file_id.startswith("g2m") or file_id.startswith("m2g"): + assert isinstance(result, torch.Tensor) + + if file_id.endswith("_index"): + assert result.shape[0] == 2 # adjacency matrix uses two rows + elif file_id.endswith("_features"): + assert result.shape[1] == d_features + + elif file_id.startswith("m2m") or file_id.startswith("mesh"): + assert isinstance(result, list) + if not hierarchical: + assert len(result) == 1 + else: + if file_id.startswith("mesh_up") or file_id.startswith( + "mesh_down" + ): + assert len(result) == num_levels - 1 + else: + assert len(result) == num_levels + + for r in result: + assert isinstance(r, torch.Tensor) + + if file_id == "m2m_node_features": + assert r.shape[1] == d_mesh_static + elif file_id.endswith("_index"): + assert r.shape[0] == 2 # adjacency matrix uses two rows + elif file_id.endswith("_features"): + assert r.shape[1] == d_features diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore new file mode 100644 index 00000000..4fbd2326 --- /dev/null +++ b/tests/datastore_examples/.gitignore @@ -0,0 +1,3 @@ +npyfilesmeps/*.zip +npyfilesmeps/meps_example_reduced +npyfilesmeps/era5_1000hPa_temp_meps_example_reduced diff --git a/tests/datastore_examples/mdp/danra_100m_winds/.gitignore b/tests/datastore_examples/mdp/danra_100m_winds/.gitignore new file mode 100644 index 00000000..f2828f46 --- /dev/null +++ b/tests/datastore_examples/mdp/danra_100m_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml new file mode 100644 index 00000000..8b3362e0 --- /dev/null +++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml @@ -0,0 +1,18 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 + t2m: 1.0 + r2m: 1.0 + output_clamping: + lower: + t2m: 0.0 + r2m: 0 + upper: + r2m: 1.0 + u100m: 100.0 diff --git a/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml new file mode 100644 index 00000000..e601cc02 --- /dev/null +++ b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,117 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface_forcing: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + - r2m + - t2m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + state_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: state + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore new file mode 100644 index 00000000..f2828f46 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml new file mode 100644 index 00000000..a158bee3 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml @@ -0,0 +1,12 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml new file mode 100644 index 00000000..3edf1267 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,99 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml new file mode 100644 index 00000000..c83489c6 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -0,0 +1,106 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + static: [grid_index, static_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_sea_level_pressure + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_static: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - land_sea_mask + dim_mapping: + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: static + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml new file mode 100644 index 00000000..27cc9764 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml @@ -0,0 +1,18 @@ +datastore: + kind: npyfilesmeps + config_path: meps_example_reduced.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + nlwrs_0: 1.0 + nswrs_0: 1.0 + pres_0g: 1.0 + pres_0s: 1.0 + r_2: 1.0 + r_65: 1.0 + t_2: 1.0 + t_65: 1.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml new file mode 100644 index 00000000..c83489c6 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -0,0 +1,106 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + static: [grid_index, static_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_sea_level_pressure + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_static: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - land_sea_mask + dim_mapping: + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: static + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml new file mode 100644 index 00000000..3d88d4a4 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml @@ -0,0 +1,44 @@ +dataset: + name: meps_example_reduced + num_forcing_features: 16 + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + var_units: + - Pa + - Pa + - W/m**2 + - W/m**2 + - '' + - '' + - K + - K + num_timesteps: 65 + num_ensemble_members: 2 + step_length: 3 +grid_shape_state: +- 134 +- 119 +projection: + class_name: LambertConformal + kwargs: + central_latitude: 63.3 + central_longitude: 15.0 + standard_parallels: + - 63.3 + - 63.3 diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py new file mode 100644 index 00000000..6b050e6b --- /dev/null +++ b/tests/dummy_datastore.py @@ -0,0 +1,433 @@ +# Standard library +import datetime +import tempfile +from functools import cached_property +from pathlib import Path +from typing import List, Union + +# Third-party +import isodate +import numpy as np +import xarray as xr +from cartopy import crs as ccrs +from numpy import ndarray + +# First-party +from neural_lam.datastore.base import ( + BaseRegularGridDatastore, + CartesianGridShape, +) + + +class DummyDatastore(BaseRegularGridDatastore): + """ + Datastore that creates some dummy data for testing purposes. The data + consists of state, forcing, and static variables, and is stored in a + regular grid (using Lambert Azimuthal Equal Area projection). The domain + is centered on Denmark and has a size of 500x500 km. + """ + + SHORT_NAME = "dummydata" + T0 = isodate.parse_datetime("1990-09-02T00:00:00") + N_FEATURES = dict(state=5, forcing=2, static=1) + CARTESIAN_COORDS = ["x", "y"] + + # center the domain on Denmark + latlon_center = [56, 10] # latitude, longitude + bbox_size_km = [500, 500] # km + + def __init__( + self, config_path=None, n_grid_points=10000, n_timesteps=15 + ) -> None: + """ + Create a dummy datastore with random data. + + Parameters + ---------- + config_path : None + No config file is needed for the dummy datastore. This argument is + only present to match the signature of the other datastores. + n_grid_points : int + The number of grid points in the dataset. Must be a perfect square. + n_timesteps : int + The number of timesteps in the dataset. + """ + assert ( + config_path is None + ), "No config file is needed for the dummy datastore" + + # Ensure n_grid_points is a perfect square + n_points_1d = int(np.sqrt(n_grid_points)) + assert ( + n_points_1d * n_points_1d == n_grid_points + ), "n_grid_points must be a perfect square" + + # create equal area grid + lx, ly = self.bbox_size_km + x = np.linspace(-lx / 2.0 * 1.0e3, lx / 2.0 * 1.0e3, n_points_1d) + y = np.linspace(-ly / 2.0 * 1.0e3, ly / 2.0 * 1.0e3, n_points_1d) + + xs, ys = np.meshgrid(x, y) + + # Create lat/lon coordinates using equal area projection + lon_mesh, lat_mesh = ( + ccrs.PlateCarree() + .transform_points( + src_crs=self.coords_projection, + x=xs.flatten(), + y=ys.flatten(), + )[:, :2] + .T + ) + + # Create base dataset with proper coordinates + self.ds = xr.Dataset( + coords={ + "x": ( + "x", + x, + {"units": "m"}, + ), # Use first column for x coordinates + "y": ( + "y", + y, + {"units": "m"}, + ), # Use first row for y coordinates + "longitude": ( + "grid_index", + lon_mesh.flatten(), + {"units": "degrees_east"}, + ), + "latitude": ( + "grid_index", + lat_mesh.flatten(), + {"units": "degrees_north"}, + ), + } + ) + # Create data variables with proper dimensions + for category, n in self.N_FEATURES.items(): + feature_names = [f"{category}_feat_{i}" for i in range(n)] + feature_units = ["-" for _ in range(n)] # Placeholder units + feature_long_names = [ + f"Long name for {name}" for name in feature_names + ] + + self.ds[f"{category}_feature"] = feature_names + self.ds[f"{category}_feature_units"] = ( + f"{category}_feature", + feature_units, + ) + self.ds[f"{category}_feature_long_name"] = ( + f"{category}_feature", + feature_long_names, + ) + + # Define dimensions and create random data + dims = ["grid_index", f"{category}_feature"] + if category != "static": + dims.append("time") + shape = (n_grid_points, n, n_timesteps) + else: + shape = (n_grid_points, n) + + # Create random data + data = np.random.randn(*shape) + + # Create DataArray with proper dimensions + self.ds[category] = xr.DataArray( + data, + dims=dims, + coords={ + f"{category}_feature": feature_names, + }, + ) + + if category != "static": + dt = datetime.timedelta(hours=self.step_length) + times = [self.T0 + dt * i for i in range(n_timesteps)] + self.ds.coords["time"] = times + + # Stack the spatial dimensions into grid_index + self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) + + # Create temporary directory for storing derived files + self._tempdir = tempfile.TemporaryDirectory() + self._root_path = Path(self._tempdir.name) + self._num_grid_points = n_grid_points + + @property + def root_path(self) -> Path: + """ + The root path to the datastore. It is relative to this that any derived + files (for example the graph components) are stored. + + Returns + ------- + pathlib.Path + The root path to the datastore. + + """ + return self._root_path + + @property + def config(self) -> dict: + """The configuration of the datastore. + + Returns + ------- + collections.abc.Mapping + The configuration of the datastore, any dict like object can be + returned. + + """ + return {} + + @property + def step_length(self) -> int: + """The step length of the dataset in hours. + + Returns: + int: The step length in hours. + + """ + return 1 + + def get_vars_names(self, category: str) -> list[str]: + """Get the names of the variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + List[str] + The names of the variables. + + """ + return self.ds[f"{category}_feature"].values.tolist() + + def get_vars_units(self, category: str) -> list[str]: + """Get the units of the variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + List[str] + The units of the variables. + + """ + return self.ds[f"{category}_feature_units"].values.tolist() + + def get_vars_long_names(self, category: str) -> List[str]: + """Get the long names of the variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + List[str] + The long names of the variables. + + """ + return self.ds[f"{category}_feature_long_name"].values.tolist() + + def get_num_data_vars(self, category: str) -> int: + """Get the number of data variables in the given category. + + Parameters + ---------- + category : str + The category of the variables (state/forcing/static). + + Returns + ------- + int + The number of data variables. + + """ + return self.ds[f"{category}_feature"].size + + def get_standardization_dataarray(self, category: str) -> xr.Dataset: + """ + Return the standardization (i.e. scaling to mean of 0.0 and standard + deviation of 1.0) dataarray for the given category. This should contain + a `{category}_mean` and `{category}_std` variable for each variable in + the category. For `category=="state"`, the dataarray should also + contain a `state_diff_mean` and `state_diff_std` variable for the one- + step differences of the state variables. The returned dataarray should + at least have dimensions of `({category}_feature)`, but can also + include for example `grid_index` (if the standardization is done per + grid point for example). + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + xr.Dataset + The standardization dataarray for the given category, with variables + for the mean and standard deviation of the variables (and + differences for state variables). + + """ + ds_standardization = xr.Dataset() + + ops = ["mean", "std"] + if category == "state": + ops += ["diff_mean", "diff_std"] + + for op in ops: + da_op = xr.ones_like(self.ds[f"{category}_feature"]).astype(float) + ds_standardization[f"{category}_{op}"] = da_op + + return ds_standardization + + def get_dataarray( + self, category: str, split: str, standardize: bool = False + ) -> Union[xr.DataArray, None]: + """ + Return the processed data (as a single `xr.DataArray`) for the given + category of data and test/train/val-split that covers all the data (in + space and time) of a given category (state/forcing/static). For the + "static" category the `split` is allowed to be `None` because the static + data is the same for all splits. + + The returned dataarray is expected to at minimum have dimensions of + `(grid_index, {category}_feature)` so that any spatial dimensions have + been stacked into a single dimension and all variables and levels have + been stacked into a single feature dimension named by the `category` of + data being loaded. + + For categories of data that have a time dimension (i.e. not static + data), the dataarray is expected additionally have `(analysis_time, + elapsed_forecast_duration)` dimensions if `is_forecast` is True, or + `(time)` if `is_forecast` is False. + + If the data is ensemble data, the dataarray is expected to have an + additional `ensemble_member` dimension. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + split : str + The time split to filter the dataset (train/val/test). + standardize: bool + If the dataarray should be returned standardized + + Returns + ------- + xr.DataArray or None + The xarray DataArray object with processed dataset. + + """ + dim_order = self.expected_dim_order(category=category) + + da_category = self.ds[category].transpose(*dim_order) + + if standardize: + return self._standardize_datarray(da_category, category=category) + + return da_category + + def get_xy(self, category: str, stacked: bool = True) -> ndarray: + """Return the x, y coordinates of the dataset. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + stacked : bool + Whether to stack the x, y coordinates. + + Returns + ------- + np.ndarray + The x, y coordinates of the dataset, returned differently based on + the value of `stacked`: + - `stacked==True`: shape `(n_grid_points, 2)` where + n_grid_points=N_x*N_y. + - `stacked==False`: shape `(N_x, N_y, 2)` + + """ + # assume variables are stored in dimensions [grid_index, ...] + ds_category = self.unstack_grid_coords(da_or_ds=self.ds[category]) + + da_xs = ds_category.x + da_ys = ds_category.y + + assert da_xs.ndim == da_ys.ndim == 1, "x and y coordinates must be 1D" + + da_x, da_y = xr.broadcast(da_xs, da_ys) + da_xy = xr.concat([da_x, da_y], dim="grid_coord") + + if stacked: + da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose( + "grid_index", + "grid_coord", + ) + else: + dims = [ + "x", + "y", + "grid_coord", + ] + da_xy = da_xy.transpose(*dims) + + return da_xy.values + + @property + def coords_projection(self) -> ccrs.Projection: + """Return the projection object for the coordinates. + + The projection object is used to plot the coordinates on a map. + + Returns + ------- + cartopy.crs.Projection: + The projection object. + + """ + # make a projection centered on Denmark + lat_center, lon_center = self.latlon_center + return ccrs.LambertAzimuthalEqualArea( + central_latitude=lat_center, central_longitude=lon_center + ) + + @property + def num_grid_points(self) -> int: + """Return the number of grid points in the dataset. + + Returns + ------- + int + The number of grid points in the dataset. + + """ + return self._num_grid_points + + @cached_property + def grid_shape_state(self) -> CartesianGridShape: + """The shape of the grid for the state variables. + + Returns + ------- + CartesianGridShape: + The shape of the grid for the state variables, which has `x` and + `y` attributes. + """ + + n_points_1d = int(np.sqrt(self.num_grid_points)) + return CartesianGridShape(x=n_points_1d, y=n_points_1d) diff --git a/tests/test_clamping.py b/tests/test_clamping.py new file mode 100644 index 00000000..f3f9365d --- /dev/null +++ b/tests/test_clamping.py @@ -0,0 +1,283 @@ +# Standard library +from pathlib import Path + +# Third-party +import torch + +# First-party +from neural_lam import config as nlconfig +from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.datastore.mdp import MDPDatastore +from neural_lam.models.graph_lam import GraphLAM +from tests.conftest import init_datastore_example + + +def test_clamping(): + datastore = init_datastore_example(MDPDatastore.SHORT_NAME) + + graph_name = "1level" + + graph_dir_path = Path(datastore.root_path) / "graph" / graph_name + + if not graph_dir_path.exists(): + create_graph_from_datastore( + datastore=datastore, + output_root_path=str(graph_dir_path), + n_max_levels=1, + ) + + class ModelArgs: + output_std = False + loss = "mse" + restore_opt = False + n_example_pred = 1 + graph = graph_name + hidden_dim = 4 + hidden_layers = 1 + processor_layers = 2 + mesh_aggr = "sum" + lr = 1.0e-3 + val_steps_to_log = [1, 3] + metrics_watch = [] + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + + model_args = ModelArgs() + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ), + training=nlconfig.TrainingConfig( + output_clamping=nlconfig.OutputClamping( + lower={"t2m": 0.0, "r2m": 0.0}, + upper={"r2m": 1.0, "u100m": 100.0}, + ) + ), + ) + + model = GraphLAM( + args=model_args, + datastore=datastore, + config=config, + ) + + features = datastore.get_vars_names(category="state") + original_state = torch.zeros(1, 1, len(features)) + zero_delta = original_state.clone() + + # Get a state well within the bounds + original_state[:, :, model.clamp_lower_upper_idx] = ( + model.sigmoid_lower_lims + model.sigmoid_upper_lims + ) / 2 + original_state[:, :, model.clamp_lower_idx] = model.softplus_lower_lims + 10 + original_state[:, :, model.clamp_upper_idx] = model.softplus_upper_lims - 10 + + # Get a delta that tries to push the state out of bounds + delta = torch.ones_like(zero_delta) + delta[:, :, model.clamp_lower_upper_idx] = ( + model.sigmoid_upper_lims - model.sigmoid_lower_lims + ) / 3 + delta[:, :, model.clamp_lower_idx] = -5 + delta[:, :, model.clamp_upper_idx] = 5 + + # Check that a delta of 0 gives unchanged state + zero_prediction = model.get_clamped_new_state(zero_delta, original_state) + assert (abs(original_state - zero_prediction) < 1e-6).all().item() + + # Make predictions towards bounds for each feature + prediction = zero_prediction.clone() + n_loops = 100 + for i in range(n_loops): + prediction = model.get_clamped_new_state(delta, prediction) + + # check that unclamped states are as expected + # delta is 1, so they should be 1*n_loops + assert ( + ( + abs( + prediction[ + :, + :, + list( + set(range(len(features))) + - set(model.clamp_lower_upper_idx.tolist()) + - set(model.clamp_lower_idx.tolist()) + - set(model.clamp_upper_idx.tolist()) + ), + ] + - n_loops + ) + < 1e-6 + ) + .all() + .item() + ) + + # Check that clamped states are within bounds + # they should not be at the bounds but allow it due to numerical precision + assert ( + ( + model.sigmoid_lower_lims + <= prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) + assert ( + (model.softplus_lower_lims <= prediction[:, :, model.clamp_lower_idx]) + .all() + .item() + ) + assert ( + (prediction[:, :, model.clamp_upper_idx] <= model.softplus_upper_lims) + .all() + .item() + ) + + # Check that prediction is within bounds in original non-normalized space + unscaled_prediction = prediction * model.state_std + model.state_mean + features_idx = {f: i for i, f in enumerate(features)} + lower_lims = { + features_idx[f]: lim + for f, lim in config.training.output_clamping.lower.items() + } + upper_lims = { + features_idx[f]: lim + for f, lim in config.training.output_clamping.upper.items() + } + assert ( + ( + torch.tensor(list(lower_lims.values())) + <= unscaled_prediction[:, :, list(lower_lims.keys())] + ) + .all() + .item() + ) + assert ( + ( + unscaled_prediction[:, :, list(upper_lims.keys())] + <= torch.tensor(list(upper_lims.values())) + ) + .all() + .item() + ) + + # Check that a prediction from a state starting outside the bounds is also + # pushed within bounds. 3 delta should be enough to give an initial state + # out of bounds so 5 is well outside + invalid_state = original_state + 5 * delta + assert ( + not ( + model.sigmoid_lower_lims + <= invalid_state[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .any() + .item() + ) + assert ( + not ( + model.softplus_lower_lims + <= invalid_state[:, :, model.clamp_lower_idx] + ) + .any() + .item() + ) + assert ( + not ( + invalid_state[:, :, model.clamp_upper_idx] + <= model.softplus_upper_lims + ) + .any() + .item() + ) + invalid_prediction = model.get_clamped_new_state(zero_delta, invalid_state) + assert ( + ( + model.sigmoid_lower_lims + <= invalid_prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) + assert ( + ( + model.softplus_lower_lims + <= invalid_prediction[:, :, model.clamp_lower_idx] + ) + .all() + .item() + ) + assert ( + ( + invalid_prediction[:, :, model.clamp_upper_idx] + <= model.softplus_upper_lims + ) + .all() + .item() + ) + + # Above tests only check the upper sigmoid limit. + # Repeat to check lower sigmoid limit + + # Make predictions towards bounds for each feature + prediction = zero_prediction.clone() + n_loops = 100 + for i in range(n_loops): + prediction = model.get_clamped_new_state(-delta, prediction) + + # Check that clamped states are within bounds + assert ( + ( + model.sigmoid_lower_lims + <= prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) + + # Check that prediction is within bounds in original non-normalized space + assert ( + ( + torch.tensor(list(lower_lims.values())) + <= unscaled_prediction[:, :, list(lower_lims.keys())] + ) + .all() + .item() + ) + assert ( + ( + unscaled_prediction[:, :, list(upper_lims.keys())] + <= torch.tensor(list(upper_lims.values())) + ) + .all() + .item() + ) + + # Check that a prediction from a state starting outside the bounds is also + # pushed within bounds. 3 delta should be enough to give an initial state + # out of bounds so 5 is well outside + invalid_state = original_state - 5 * delta + assert ( + not ( + model.sigmoid_lower_lims + <= invalid_state[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .any() + .item() + ) + invalid_prediction = model.get_clamped_new_state(zero_delta, invalid_state) + assert ( + ( + model.sigmoid_lower_lims + <= invalid_prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..cd6b00eb --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,12 @@ +# First-party +import neural_lam +import neural_lam.build_rectangular_graph +import neural_lam.train_model + + +def test_import(): + """This test just ensures that each cli entry-point can be imported for now, + eventually we should test their execution too.""" + assert neural_lam is not None + assert neural_lam.build_rectangular_graph is not None + assert neural_lam.train_model is not None diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..1ff40bc6 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,72 @@ +# Third-party +import pytest + +# First-party +import neural_lam.config as nlconfig + + +@pytest.mark.parametrize( + "state_weighting_config", + [ + nlconfig.ManualStateFeatureWeighting( + weights=dict(u100m=1.0, v100m=0.5) + ), + nlconfig.UniformFeatureWeighting(), + ], +) +def test_config_serialization(state_weighting_config): + c = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection(kind="mdp", config_path=""), + training=nlconfig.TrainingConfig( + state_feature_weighting=state_weighting_config + ), + ) + + assert c == c.from_json(c.to_json()) + assert c == c.from_yaml(c.to_yaml()) + + +yaml_training_defaults = """ +datastore: + kind: mdp + config_path: "" +""" + +default_config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection(kind="mdp", config_path=""), + training=nlconfig.TrainingConfig( + state_feature_weighting=nlconfig.UniformFeatureWeighting() + ), +) + +yaml_training_manual_weights = """ +datastore: + kind: mdp + config_path: "" +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 +""" + +manual_weights_config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection(kind="mdp", config_path=""), + training=nlconfig.TrainingConfig( + state_feature_weighting=nlconfig.ManualStateFeatureWeighting( + weights=dict(u100m=1.0, v100m=1.0) + ) + ), +) + +yaml_samples = zip( + [yaml_training_defaults, yaml_training_manual_weights], + [default_config, manual_weights_config], +) + + +@pytest.mark.parametrize("yaml_str, config_expected", yaml_samples) +def test_config_load_from_yaml(yaml_str, config_expected): + c = nlconfig.NeuralLAMConfig.from_yaml(yaml_str) + assert c == config_expected diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 00000000..757127a6 --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,315 @@ +# Standard library +from pathlib import Path + +# Third-party +import numpy as np +import pytest +import torch +from torch.utils.data import DataLoader + +# First-party +from neural_lam import config as nlconfig +from neural_lam.build_rectangular_graph import build_graph_from_archetype +from neural_lam.datastore import DATASTORES +from neural_lam.datastore.base import BaseRegularGridDatastore +from neural_lam.models.graph_lam import GraphLAM +from neural_lam.weather_dataset import WeatherDataset +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + get_test_mesh_dist, + init_datastore_boundary_example, + init_datastore_example, +) +from tests.dummy_datastore import DummyDatastore + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_shapes(datastore_name, datastore_boundary_name): + """Check that the `datastore.get_dataarray` method is implemented. + + Validate the shapes of the tensors match between the different + components of the training sample. + + init_states: (2, N_grid, d_features) + target_states: (ar_steps, N_grid, d_features) + forcing: (ar_steps, N_grid, d_windowed_forcing) # batch_times: (ar_steps,) + + """ + datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) + N_gridpoints = datastore.num_grid_points + N_gridpoints_boundary = datastore_boundary.num_grid_points + + N_pred_steps = 4 + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=datastore_boundary, + split="train", + ar_steps=N_pred_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, + ) + + item = dataset[0] + + # unpack the item, this is the current return signature for + # WeatherDataset.__getitem__ + init_states, target_states, forcing, boundary, target_times = item + + # initial states + assert init_states.ndim == 3 + assert init_states.shape[0] == 2 # two time steps go into the input + assert init_states.shape[1] == N_gridpoints + assert init_states.shape[2] == datastore.get_num_data_vars("state") + + # output states + assert target_states.ndim == 3 + assert target_states.shape[0] == N_pred_steps + assert target_states.shape[1] == N_gridpoints + assert target_states.shape[2] == datastore.get_num_data_vars("state") + + # forcing + assert forcing.ndim == 3 + assert forcing.shape[0] == N_pred_steps + assert forcing.shape[1] == N_gridpoints + # each time step in the window has one corresponding time deltas + # that is shared across all grid points, times and variables + assert forcing.shape[2] == (datastore.get_num_data_vars("forcing") + 1) * ( + num_past_forcing_steps + num_future_forcing_steps + 1 + ) + + # boundary + assert boundary.ndim == 3 + assert boundary.shape[0] == N_pred_steps + assert boundary.shape[1] == N_gridpoints_boundary + assert boundary.shape[2] == ( + datastore_boundary.get_num_data_vars("forcing") + 1 + ) * (num_past_boundary_steps + num_future_boundary_steps + 1) + + # batch times + assert target_times.ndim == 1 + assert target_times.shape[0] == N_pred_steps + + # try to get the last item of the dataset to ensure slicing and stacking + # operations are working as expected and are consistent with the dataset + # length + dataset[len(dataset) - 1] + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_dataset_item_create_dataarray_from_tensor(datastore_name): + datastore = init_datastore_example(datastore_name) + + N_pred_steps = 4 + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + split="train", + ar_steps=N_pred_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + ) + + idx = 0 + + # unpack the item, this is the current return signature for + # WeatherDataset.__getitem__ + _, target_states, _, _, target_times_arr = dataset[idx] + ( + _, + da_target_true, + _, + _, + da_target_times_true, + ) = dataset._build_item_dataarrays(idx=idx) + + target_times = np.array(target_times_arr, dtype="datetime64[ns]") + np.testing.assert_equal(target_times, da_target_times_true.values) + + da_target = dataset.create_dataarray_from_tensor( + tensor=target_states, category="state", time=target_times + ) + + # conversion to torch.float32 may lead to loss of precision + np.testing.assert_allclose( + da_target.values, da_target_true.values, rtol=1e-6 + ) + assert da_target.dims == da_target_true.dims + for dim in da_target.dims: + np.testing.assert_equal( + da_target[dim].values, da_target_true[dim].values + ) + + if isinstance(datastore, BaseRegularGridDatastore): + # test unstacking the grid coordinates + da_target_unstacked = datastore.unstack_grid_coords(da_target) + assert all( + coord_name in da_target_unstacked.coords + for coord_name in ["x", "y"] + ) + + # check construction of a single time + da_target_single = dataset.create_dataarray_from_tensor( + tensor=target_states[0], category="state", time=target_times[0] + ) + + # check that the content is the same + # conversion to torch.float32 may lead to loss of precision + np.testing.assert_allclose( + da_target_single.values, da_target_true[0].values, rtol=1e-6 + ) + assert da_target_single.dims == da_target_true[0].dims + for dim in da_target_single.dims: + np.testing.assert_equal( + da_target_single[dim].values, da_target_true[0][dim].values + ) + + if isinstance(datastore, BaseRegularGridDatastore): + # test unstacking the grid coordinates + da_target_single_unstacked = datastore.unstack_grid_coords( + da_target_single + ) + assert all( + coord_name in da_target_single_unstacked.coords + for coord_name in ["x", "y"] + ) + + +@pytest.mark.parametrize("split", ["train", "val", "test"]) +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_single_batch(datastore_name, datastore_boundary_name, split): + """Check that the `datastore.get_dataarray` method is implemented. + + And that it returns an xarray DataArray with the correct dimensions. + + """ + datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) + + device_name = ( + torch.device("cuda") if torch.cuda.is_available() else "cpu" + ) # noqa + + flat_graph_name = "1level" + + class ModelArgs: + output_std = False + loss = "mse" + restore_opt = False + n_example_pred = 1 + graph_name = flat_graph_name + hidden_dim = 4 + hidden_layers = 1 + processor_layers = 2 + mesh_aggr = "sum" + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + shared_grid_embedder = False + + args = ModelArgs() + + graph_dir_path = Path(datastore.root_path) / "graphs" / flat_graph_name + + def _create_graph(): + if not graph_dir_path.exists(): + build_graph_from_archetype( + datastore=datastore, + datastore_boundary=datastore_boundary, + graph_name=flat_graph_name, + archetype="keisler", + mesh_node_distance=get_test_mesh_dist( + datastore, datastore_boundary + ), + ) + + if not isinstance(datastore, BaseRegularGridDatastore): + with pytest.raises(NotImplementedError): + _create_graph() + pytest.skip("Skipping on model-run on non-regular grid datastores") + + _create_graph() + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + dataset = WeatherDataset( + datastore=datastore, datastore_boundary=datastore_boundary, split=split + ) + + model = GraphLAM( + args=args, + datastore=datastore, + datastore_boundary=datastore_boundary, + config=config, + ) # noqa + + model_device = model.to(device_name) + data_loader = DataLoader(dataset, batch_size=2) + batch = next(iter(data_loader)) + batch_device = [part.to(device_name) for part in batch] + model_device.common_step(batch_device) + model_device.training_step(batch_device) + + +@pytest.mark.parametrize( + "dataset_config", + [ + {"past": 0, "future": 0, "ar_steps": 1, "exp_len_reduction": 3}, + {"past": 2, "future": 0, "ar_steps": 1, "exp_len_reduction": 3}, + {"past": 0, "future": 2, "ar_steps": 1, "exp_len_reduction": 5}, + {"past": 4, "future": 0, "ar_steps": 1, "exp_len_reduction": 5}, + {"past": 0, "future": 0, "ar_steps": 5, "exp_len_reduction": 7}, + {"past": 3, "future": 3, "ar_steps": 2, "exp_len_reduction": 8}, + ], +) +def test_dataset_length(dataset_config): + """Check that correct number of samples can be extracted from the dataset, + given a specific configuration of forcing windowing and ar_steps. + """ + # Use dummy datastore of length 10 here, only want to test slicing + # in dataset class + ds_len = 10 + datastore = DummyDatastore(n_timesteps=ds_len) + + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + split="train", + ar_steps=dataset_config["ar_steps"], + num_past_forcing_steps=dataset_config["past"], + num_future_forcing_steps=dataset_config["future"], + ) + + # We expect dataset to contain this many samples + expected_len = ds_len - dataset_config["exp_len_reduction"] + + # Check that datast has correct length + assert len(dataset) == expected_len + + # Check that we can actually get last and first sample + dataset[0] + dataset[expected_len - 1] diff --git a/tests/test_datastores.py b/tests/test_datastores.py new file mode 100644 index 00000000..ff7435c9 --- /dev/null +++ b/tests/test_datastores.py @@ -0,0 +1,384 @@ +"""List of methods and attributes that should be implemented in a subclass of +`` (these are all decorated with `@abc.abstractmethod`): + +- `root_path` (property): Root path of the datastore. +- `step_length` (property): Length of the time step in hours. +- `grid_shape_state` (property): Shape of the grid for the state variables. +- `get_xy` (method): Return the x, y coordinates of the dataset. +- `coords_projection` (property): Projection object for the coordinates. +- `get_vars_units` (method): Get the units of the variables in the given + category. +- `get_vars_names` (method): Get the names of the variables in the given + category. +- `get_vars_long_names` (method): Get the long names of the variables in + the given category. +- `get_num_data_vars` (method): Get the number of data variables in the + given category. +- `get_normalization_dataarray` (method): Return the normalization + dataarray for the given category. +- `get_dataarray` (method): Return the processed data (as a single + `xr.DataArray`) for the given category and test/train/val-split. +- `config` (property): Return the configuration of the datastore. + +In addition BaseRegularGridDatastore must have the following methods and +attributes: +- `get_xy_extent` (method): Return the extent of the x, y coordinates for a + given category of data. +- `get_xy` (method): Return the x, y coordinates of the dataset. +- `coords_projection` (property): Projection object for the coordinates. +- `grid_shape_state` (property): Shape of the grid for the state variables. +- `stack_grid_coords` (method): Stack the grid coordinates of the dataset + +""" + +# Standard library +import collections +import dataclasses +from pathlib import Path + +# Third-party +import cartopy.crs as ccrs +import numpy as np +import pytest +import torch +import xarray as xr + +# First-party +from neural_lam.datastore import DATASTORES +from neural_lam.datastore.base import BaseRegularGridDatastore +from neural_lam.datastore.plot_example import plot_example_from_datastore +from tests.conftest import init_datastore_example + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_root_path(datastore_name): + """Check that the `datastore.root_path` property is implemented.""" + datastore = init_datastore_example(datastore_name) + assert isinstance(datastore.root_path, Path) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_config(datastore_name): + """Check that the `datastore.config` property is implemented.""" + datastore = init_datastore_example(datastore_name) + # check the config is a mapping or a dataclass + config = datastore.config + assert isinstance( + config, collections.abc.Mapping + ) or dataclasses.is_dataclass(config) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_step_length(datastore_name): + """Check that the `datastore.step_length` property is implemented.""" + datastore = init_datastore_example(datastore_name) + step_length = datastore.step_length + assert isinstance(step_length, int) + assert step_length > 0 + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_datastore_grid_xy(datastore_name): + """Use the `datastore.get_xy` method to get the x, y coordinates of the + dataset and check that the shape is correct against the `da + tastore.grid_shape_state` property.""" + datastore = init_datastore_example(datastore_name) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip( + "Skip grid_shape_state test for non-regular grid datastores" + ) + + # check the shapes of the xy grid + grid_shape = datastore.grid_shape_state + nx, ny = grid_shape.x, grid_shape.y + for stacked in [True, False]: + xy = datastore.get_xy("static", stacked=stacked) + if stacked: + assert xy.shape == (nx * ny, 2) + else: + assert xy.shape == (nx, ny, 2) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_get_vars(datastore_name): + """Check that results of. + + - `datastore.get_vars_units` + - `datastore.get_vars_names` + - `datastore.get_vars_long_names` + - `datastore.get_num_data_vars` + + are consistent (as in the number of variables are the same) and that the + return types of each are correct. + + """ + datastore = init_datastore_example(datastore_name) + + for category in ["state", "forcing", "static"]: + units = datastore.get_vars_units(category) + names = datastore.get_vars_names(category) + long_names = datastore.get_vars_long_names(category) + num_vars = datastore.get_num_data_vars(category) + + assert len(units) == len(names) == num_vars + assert isinstance(units, list) + assert isinstance(names, list) + assert isinstance(long_names, list) + assert isinstance(num_vars, int) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_get_normalization_dataarray(datastore_name): + """Check that the `datastore.get_normalization_dataa rray` method is + implemented.""" + datastore = init_datastore_example(datastore_name) + + for category in ["state", "forcing", "static"]: + ds_stats = datastore.get_standardization_dataarray(category=category) + + # check that the returned object is an xarray DataArray + # and that it has the correct variables + assert isinstance(ds_stats, xr.Dataset) + + if category == "state": + ops = ["mean", "std", "diff_mean", "diff_std"] + elif category == "forcing": + ops = ["mean", "std"] + elif category == "static": + ops = [] + else: + raise NotImplementedError(category) + + for op in ops: + var_name = f"{category}_{op}" + assert var_name in ds_stats.data_vars + da_val = ds_stats[var_name] + assert set(da_val.dims) == {f"{category}_feature"} + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_get_dataarray(datastore_name): + """Check that the `datastore.get_dataarray` method is implemented. + + And that it returns an xarray DataArray with the correct dimensions. + + """ + + datastore = init_datastore_example(datastore_name) + + for category in ["state", "forcing", "static"]: + n_features = {} + if category in ["state", "forcing"]: + splits = ["train", "val", "test"] + elif category == "static": + # static data should be the same for all splits, so split + # should be allowed to be None + splits = ["train", "val", "test", None] + else: + raise NotImplementedError(category) + + for split in splits: + expected_dims = ["grid_index", f"{category}_feature"] + if category != "static": + if not datastore.is_forecast: + expected_dims.append("time") + else: + expected_dims += [ + "analysis_time", + "elapsed_forecast_duration", + ] + + if datastore.is_ensemble and category == "state": + # assume that only state variables change with ensemble members + expected_dims.append("ensemble_member") + + # XXX: for now we only have a single attribute to get the shape of + # the grid which uses the shape from the "state" category, maybe + # this should change? + + da = datastore.get_dataarray(category=category, split=split) + + assert isinstance(da, xr.DataArray) + assert set(da.dims) == set(expected_dims) + if isinstance(datastore, BaseRegularGridDatastore): + grid_shape = datastore.grid_shape_state + assert da.grid_index.size == grid_shape.x * grid_shape.y + + n_features[split] = da[category + "_feature"].size + + # check that the number of features is the same for all splits + assert n_features["train"] == n_features["val"] == n_features["test"] + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_get_xy_extent(datastore_name): + """Check that the `datastore.get_xy_extent` method is implemented and that + the returned object is a tuple of the correct length.""" + datastore = init_datastore_example(datastore_name) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip("Datastore does not implement `BaseCartesianDatastore`") + + extents = {} + # get the extents for each category, and finally check they are all the same + for category in ["state", "forcing", "static"]: + extent = datastore.get_xy_extent(category) + assert isinstance(extent, list) + assert len(extent) == 4 + assert all(isinstance(e, (int, float)) for e in extent) + extents[category] = extent + + # check that the extents are the same for all categories + for category in ["forcing", "static"]: + assert extents["state"] == extents[category] + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_get_xy(datastore_name): + """Check that the `datastore.get_xy` method is implemented.""" + datastore = init_datastore_example(datastore_name) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip("Datastore does not implement `BaseCartesianDatastore`") + + for category in ["state", "forcing", "static"]: + xy_stacked = datastore.get_xy(category=category, stacked=True) + xy_unstacked = datastore.get_xy(category=category, stacked=False) + + assert isinstance(xy_stacked, np.ndarray) + assert isinstance(xy_unstacked, np.ndarray) + + nx, ny = datastore.grid_shape_state.x, datastore.grid_shape_state.y + + # for stacked=True, the shape should be (n_grid_points, 2) + assert xy_stacked.ndim == 2 + assert xy_stacked.shape[0] == nx * ny + assert xy_stacked.shape[1] == 2 + + # for stacked=False, the shape should be (nx, ny, 2) + assert xy_unstacked.ndim == 3 + assert xy_unstacked.shape[0] == nx + assert xy_unstacked.shape[1] == ny + assert xy_unstacked.shape[2] == 2 + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_get_projection(datastore_name): + """Check that the `datastore.coords_projection` property is implemented.""" + datastore = init_datastore_example(datastore_name) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip("Datastore does not implement `BaseCartesianDatastore`") + + assert isinstance(datastore.coords_projection, ccrs.Projection) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def get_grid_shape_state(datastore_name): + """Check that the `datastore.grid_shape_state` property is implemented.""" + datastore = init_datastore_example(datastore_name) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip("Datastore does not implement `BaseCartesianDatastore`") + + grid_shape = datastore.grid_shape_state + assert isinstance(grid_shape, tuple) + assert len(grid_shape) == 2 + assert all(isinstance(e, int) for e in grid_shape) + assert all(e > 0 for e in grid_shape) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize("category", ["state", "forcing", "static"]) +def test_stacking_grid_coords(datastore_name, category): + """Check that the `datastore.stack_grid_coords` method is implemented.""" + datastore = init_datastore_example(datastore_name) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip("Datastore does not implement `BaseCartesianDatastore`") + + da_static = datastore.get_dataarray(category=category, split="train") + + da_static_unstacked = datastore.unstack_grid_coords(da_static).load() + da_static_test = datastore.stack_grid_coords(da_static_unstacked) + + assert da_static.dims == da_static_test.dims + xr.testing.assert_equal(da_static, da_static_test) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_dataarray_shapes(datastore_name): + datastore = init_datastore_example(datastore_name) + static_da = datastore.get_dataarray("static", split=None) + static_da = datastore.stack_grid_coords(static_da) + static_da = static_da.isel(static_feature=0) + + # Convert the unstacked grid coordinates and static data array to tensors + unstacked_tensor = torch.tensor( + datastore.unstack_grid_coords(static_da).to_numpy(), dtype=torch.float32 + ).squeeze() + + reshaped_tensor = ( + torch.tensor(static_da.to_numpy(), dtype=torch.float32) + .reshape(datastore.grid_shape_state.x, datastore.grid_shape_state.y) + .squeeze() + ) + + # Compute the difference + diff = unstacked_tensor - reshaped_tensor + + # Check the shapes + assert unstacked_tensor.shape == ( + datastore.grid_shape_state.x, + datastore.grid_shape_state.y, + ) + assert reshaped_tensor.shape == ( + datastore.grid_shape_state.x, + datastore.grid_shape_state.y, + ) + assert diff.shape == ( + datastore.grid_shape_state.x, + datastore.grid_shape_state.y, + ) + # assert diff == 0 with tolerance 1e-6 + assert torch.allclose(diff, torch.zeros_like(diff), atol=1e-6) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_plot_example_from_datastore(datastore_name): + """Check that the `plot_example_from_datastore` function is implemented.""" + datastore = init_datastore_example(datastore_name) + fig = plot_example_from_datastore( + category="static", + datastore=datastore, + col_dim="{category}_feature", + split="train", + standardize=True, + selection={}, + index_selection={}, + ) + + assert fig is not None + assert fig.get_axes() + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize("category", ("state", "static")) +def test_get_standardized_da(datastore_name, category): + """Check that dataarray is actually standardized when calling + get_dataarray with standardize=True""" + datastore = init_datastore_example(datastore_name) + ds_stats = datastore.get_standardization_dataarray(category=category) + + mean = ds_stats[f"{category}_mean"] + std = ds_stats[f"{category}_std"] + + non_standard_da = datastore.get_dataarray( + category=category, split="train", standardize=False + ) + standard_da = datastore.get_dataarray( + category=category, split="train", standardize=True + ) + + assert np.allclose(standard_da, (non_standard_da - mean) / std, atol=1e-6) diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py new file mode 100644 index 00000000..a2335dfa --- /dev/null +++ b/tests/test_graph_creation.py @@ -0,0 +1,169 @@ +# Standard library +import tempfile +from pathlib import Path + +# Third-party +import pytest + +# First-party +from neural_lam.build_rectangular_graph import ( + build_graph, + build_graph_from_archetype, +) +from neural_lam.datastore import DATASTORES +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + check_saved_graph, + get_test_mesh_dist, + init_datastore_boundary_example, + init_datastore_example, +) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", + list(DATASTORES_BOUNDARY_EXAMPLES.keys()) + [None], +) +@pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"]) +def test_build_archetype(datastore_name, datastore_boundary_name, archetype): + """Check that the `build_graph_from_archetype` function is implemented. + And that the graph is created in the correct location. + """ + datastore = init_datastore_example(datastore_name) + + if datastore_boundary_name is None: + # LAM scale + datastore_boundary = None + else: + # Global scale, ERA5 coords flattened with proj + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) + + create_kwargs = { + "mesh_node_distance": get_test_mesh_dist(datastore, datastore_boundary), + } + + if archetype == "keisler": + num_levels = 1 + else: + # Add additional multi-level kwargs + num_levels = 2 + create_kwargs.update( + { + "level_refinement_factor": 3, + "max_num_levels": num_levels, + } + ) + + # Name graph + graph_name = f"{datastore_name}_{datastore_boundary_name}_{archetype}" + + # Saved in temporary dir + with tempfile.TemporaryDirectory() as tmpdir: + graph_saving_path = Path(tmpdir) / "graphs" + graph_dir_path = graph_saving_path / graph_name + + build_graph_from_archetype( + datastore, + datastore_boundary, + graph_name, + archetype, + dir_save_path=graph_saving_path, + **create_kwargs, + ) + + hierarchical = archetype == "hierarchical" + check_saved_graph(graph_dir_path, hierarchical, num_levels) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", + list(DATASTORES_BOUNDARY_EXAMPLES.keys()) + [None], +) +@pytest.mark.parametrize( + "config_i, graph_kwargs", + enumerate( + [ + # Assortment of options + { + "m2m_connectivity": "flat", + "m2g_connectivity": "nearest_neighbour", + "g2m_connectivity": "nearest_neighbour", + "m2m_connectivity_kwargs": {}, + }, + { + "m2m_connectivity": "flat_multiscale", + "m2g_connectivity": "nearest_neighbours", + "g2m_connectivity": "within_radius", + "m2m_connectivity_kwargs": { + "level_refinement_factor": 3, + "max_num_levels": None, + }, + "m2g_connectivity_kwargs": { + "max_num_neighbours": 4, + }, + "g2m_connectivity_kwargs": { + "rel_max_dist": 0.3, + }, + }, + { + "m2m_connectivity": "hierarchical", + "m2g_connectivity": "containing_rectangle", + "g2m_connectivity": "within_radius", + "m2m_connectivity_kwargs": { + "level_refinement_factor": 2, + "max_num_levels": 2, + }, + "m2g_connectivity_kwargs": {}, + "g2m_connectivity_kwargs": { + "rel_max_dist": 0.51, + }, + }, + ] + ), +) +def test_build_from_options( + datastore_name, datastore_boundary_name, config_i, graph_kwargs +): + """Check that the `build_graph_from_archetype` function is implemented. + And that the graph is created in the correct location. + + """ + datastore = init_datastore_example(datastore_name) + + if datastore_boundary_name is None: + # LAM scale + datastore_boundary = None + else: + # Global scale, ERA5 coords flattened with proj + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) + + # Insert mesh distance + graph_kwargs["m2m_connectivity_kwargs"][ + "mesh_node_distance" + ] = get_test_mesh_dist(datastore, datastore_boundary) + + # Name graph + graph_name = f"{datastore_name}_{datastore_boundary_name}_config{config_i}" + + # Save in temporary dir + with tempfile.TemporaryDirectory() as tmpdir: + graph_saving_path = Path(tmpdir) / "graphs" + graph_dir_path = graph_saving_path / graph_name + + build_graph( + datastore, + datastore_boundary, + graph_name, + dir_save_path=graph_saving_path, + **graph_kwargs, + ) + + hierarchical = graph_kwargs["m2m_connectivity"] == "hierarchical" + num_levels = 2 if hierarchical else 1 + check_saved_graph(graph_dir_path, hierarchical, num_levels) diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 00000000..e7bbd356 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,8 @@ +# First-party +import neural_lam +import neural_lam.vis + + +def test_import(): + assert neural_lam is not None + assert neural_lam.vis is not None diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py new file mode 100644 index 00000000..9345c04b --- /dev/null +++ b/tests/test_time_slicing.py @@ -0,0 +1,540 @@ +# Third-party +import numpy as np +import pytest +import xarray as xr + +# First-party +from neural_lam.datastore import DATASTORES +from neural_lam.datastore.base import BaseDatastore +from neural_lam.weather_dataset import WeatherDataset +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) + + +class SinglePointDummyDatastore(BaseDatastore): + step_length = 1 + config = None + coords_projection = None + num_grid_points = 1 + root_path = None + + def __init__(self, time_values, state_data, forcing_data, is_forecast): + self.is_forecast = is_forecast + if is_forecast: + self._analysis_times, self._forecast_times = time_values + self._state_data = np.array(state_data) + self._forcing_data = np.array(forcing_data) + # state_data and forcing_data should be 2D arrays with shape + # (n_analysis_times, n_forecast_times) + else: + self._time_values = np.array(time_values) + self._state_data = np.array(state_data) + self._forcing_data = np.array(forcing_data) + + if is_forecast: + assert self._state_data.ndim == 2 + else: + assert self._state_data.ndim == 1 + + def get_num_data_vars(self, category): + return 1 + + def get_dataarray(self, category, split): + if self.is_forecast: + if category == "state": + # Create DataArray with dims ('analysis_time', + # 'elapsed_forecast_duration') + da = xr.DataArray( + self._state_data, + dims=["analysis_time", "elapsed_forecast_duration"], + coords={ + "analysis_time": self._analysis_times, + "elapsed_forecast_duration": self._forecast_times, + }, + ) + elif category == "forcing": + da = xr.DataArray( + self._forcing_data, + dims=["analysis_time", "elapsed_forecast_duration"], + coords={ + "analysis_time": self._analysis_times, + "elapsed_forecast_duration": self._forecast_times, + }, + ) + else: + raise NotImplementedError(category) + # Add 'grid_index' and '{category}_feature' dimensions + da = da.expand_dims("grid_index") + da = da.expand_dims(f"{category}_feature") + dim_order = self.expected_dim_order(category=category) + return da.transpose(*dim_order) + else: + if category == "state": + values = self._state_data + elif category == "forcing": + values = self._forcing_data + else: + raise NotImplementedError(category) + + if self.is_forecast: + raise NotImplementedError() + else: + da = xr.DataArray( + values, dims=["time"], coords={"time": self._time_values} + ) + + # add `{category}_feature` and `grid_index` dimensions + da = da.expand_dims("grid_index") + da = da.expand_dims(f"{category}_feature") + + dim_order = self.expected_dim_order(category=category) + return da.transpose(*dim_order) + + def get_standardization_dataarray(self, category): + raise NotImplementedError() + + def get_xy(self, category): + raise NotImplementedError() + + def get_vars_units(self, category): + raise NotImplementedError() + + def get_vars_names(self, category): + raise NotImplementedError() + + def get_vars_long_names(self, category): + raise NotImplementedError() + + +class BoundaryDummyDatastore(SinglePointDummyDatastore): + """Dummy datastore with 6h timesteps for testing boundary conditions""" + + step_length = 6 # 6 hour timesteps + + +INIT_STEPS = 2 + +STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +FORCING_VALUES = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + +STATE_VALUES_FORECAST = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], # Analysis time 0 + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], # Analysis time 1 + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], # Analysis time 2 +] +FORCING_VALUES_FORECAST = [ + [100, 101, 102, 103, 104, 105, 106, 107, 108, 109], # Analysis time 0 + [110, 111, 112, 113, 114, 115, 116, 117, 118, 119], # Analysis time 1 + [120, 121, 122, 123, 124, 125, 126, 127, 128, 129], # Analysis time 2 +] + +SCENARIOS = [ + [3, 0, 0], + [3, 1, 0], + [3, 2, 0], + [3, 3, 0], + [3, 0, 1], + [3, 0, 2], + [3, 0, 3], + [3, 1, 1], + [3, 2, 1], + [3, 3, 1], + [3, 1, 2], + [3, 1, 3], + [3, 2, 2], + [3, 2, 3], + [3, 3, 2], + [3, 3, 3], +] + + +@pytest.mark.parametrize( + "ar_steps,num_past_forcing_steps,num_future_forcing_steps", + SCENARIOS, +) +def test_time_slicing_analysis( + ar_steps, num_past_forcing_steps, num_future_forcing_steps +): + # state and forcing variables have only one dimension, `time` + time_values = np.datetime64("2020-01-01") + np.arange(len(STATE_VALUES)) + assert len(STATE_VALUES) == len(FORCING_VALUES) == len(time_values) + + datastore = SinglePointDummyDatastore( + state_data=STATE_VALUES, + forcing_data=FORCING_VALUES, + time_values=time_values, + is_forecast=False, + ) + + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + ar_steps=ar_steps, + num_future_forcing_steps=num_future_forcing_steps, + num_past_forcing_steps=num_past_forcing_steps, + standardize=False, + ) + + sample = dataset[0] + + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] + + # Some scenarios for the human reader + expected_init_states = [0, 1] + if ar_steps == 3: + expected_target_states = [2, 3, 4] + if num_past_forcing_steps == num_future_forcing_steps == 0: + expected_forcing_values = [[12], [13], [14]] + elif num_past_forcing_steps == 1 and num_future_forcing_steps == 0: + expected_forcing_values = [[11, 12], [12, 13], [13, 14]] + elif num_past_forcing_steps == 2 and num_future_forcing_steps == 0: + expected_forcing_values = [[10, 11, 12], [11, 12, 13], [12, 13, 14]] + elif num_past_forcing_steps == 3 and num_future_forcing_steps == 0: + expected_init_states = [1, 2] + expected_target_states = [3, 4, 5] + expected_forcing_values = [ + [10, 11, 12, 13], + [11, 12, 13, 14], + [12, 13, 14, 15], + ] + + # Compute expected initial states and target states based on ar_steps + offset = max(0, num_past_forcing_steps - INIT_STEPS) + init_idx = INIT_STEPS + offset + # Compute expected forcing values based on num_past_forcing_steps and + # num_future_forcing_steps for all scenarios + expected_init_states = STATE_VALUES[offset:init_idx] + expected_target_states = STATE_VALUES[init_idx : init_idx + ar_steps] + total_forcing_window = num_past_forcing_steps + num_future_forcing_steps + 1 + expected_forcing_values = [] + for i in range(ar_steps): + start_idx = i + init_idx - num_past_forcing_steps + end_idx = i + init_idx + num_future_forcing_steps + 1 + forcing_window = FORCING_VALUES[start_idx:end_idx] + expected_forcing_values.append(forcing_window) + + # init_states: (2, N_grid, d_features) + # target_states: (ar_steps, N_grid, d_features) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) + # target_times: (ar_steps,) + + # Adjust assertions to use computed expected values + assert init_states.shape == (INIT_STEPS, 1, 1) + np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) + + assert target_states.shape == (ar_steps, 1, 1) + np.testing.assert_array_equal( + target_states[:, 0, 0], expected_target_states + ) + + assert forcing.shape == ( + ar_steps, + 1, + total_forcing_window, # No time deltas for interior forcing + ) + + # Extract the forcing values from the tensor (excluding time deltas) + forcing_values = forcing[:, 0, :total_forcing_window] + + # Compare with expected forcing values + for i in range(ar_steps): + np.testing.assert_array_equal( + forcing_values[i], expected_forcing_values[i] + ) + + +@pytest.mark.parametrize( + "ar_steps,num_past_forcing_steps,num_future_forcing_steps", + SCENARIOS, +) +def test_time_slicing_forecast( + ar_steps, num_past_forcing_steps, num_future_forcing_steps +): + # Constants for forecast data + ANALYSIS_TIMES = np.datetime64("2020-01-01") + np.arange( + len(STATE_VALUES_FORECAST) + ) + ELAPSED_FORECAST_DURATION = np.timedelta64(0, "D") + np.arange( + # Retrieving the first analysis_time + len(FORCING_VALUES_FORECAST[0]) + ) + # Create a dummy datastore with forecast data + time_values = (ANALYSIS_TIMES, ELAPSED_FORECAST_DURATION) + datastore = SinglePointDummyDatastore( + state_data=STATE_VALUES_FORECAST, + forcing_data=FORCING_VALUES_FORECAST, + time_values=time_values, + is_forecast=True, + ) + + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + split="train", + ar_steps=ar_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + standardize=False, + ) + + # Test the dataset length + assert len(dataset) == len(ANALYSIS_TIMES) + + sample = dataset[0] + + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] + + # Compute expected initial states and target states based on ar_steps + offset = max(0, num_past_forcing_steps - INIT_STEPS) + init_idx = INIT_STEPS + offset + # Retrieving the first analysis_time + expected_init_states = STATE_VALUES_FORECAST[0][offset:init_idx] + expected_target_states = STATE_VALUES_FORECAST[0][ + init_idx : init_idx + ar_steps + ] + + # Compute expected forcing values based on num_past_forcing_steps and + # num_future_forcing_steps + total_forcing_window = num_past_forcing_steps + num_future_forcing_steps + 1 + expected_forcing_values = [] + for i in range(ar_steps): + start_idx = i + init_idx - num_past_forcing_steps + end_idx = i + init_idx + num_future_forcing_steps + 1 + # Retrieving the analysis_time relevant for forcing-windows (i.e. + # the first analysis_time after the 2 init_steps) + forcing_window = FORCING_VALUES_FORECAST[INIT_STEPS][start_idx:end_idx] + expected_forcing_values.append(forcing_window) + + # init_states: (2, N_grid, d_features) + # target_states: (ar_steps, N_grid, d_features) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) + # target_times: (ar_steps,) + + # Assertions + np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) + np.testing.assert_array_equal( + target_states[:, 0, 0], expected_target_states + ) + + # Verify the shape of the forcing data + expected_forcing_shape = ( + ar_steps, # Number of AR steps + 1, # Number of grid points + total_forcing_window, # Total number of forcing steps in the window + # no time deltas for interior forcing + ) + assert forcing.shape == expected_forcing_shape + + # Extract the forcing values from the tensor (excluding time deltas) + forcing_values = forcing[:, 0, :total_forcing_window] + + # Compare with expected forcing values + for i in range(ar_steps): + np.testing.assert_array_equal( + forcing_values[i], expected_forcing_values[i] + ) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +@pytest.mark.parametrize( + "subsample_config", + [ + # (interior_subsample, boundary_subsample, ar_steps) + (1, 1, 1), # Base case - no subsampling + (2, 1, 1), # Interior subsampling only + (1, 2, 1), # Boundary subsampling only + (2, 2, 1), # Equal subsampling + (2, 2, 2), # More AR steps + ], +) +def test_dataset_subsampling( + datastore_name, datastore_boundary_name, subsample_config +): + """Test that WeatherDataset handles different subsample steps correctly for + interior and boundary data. + + The test checks: + 1. Dataset creation succeeds with different subsample configurations + 2. Time differences between consecutive states match subsample steps + 3. Shapes of returned tensors are correct + 4. We can access the last item without errors + """ + interior_subsample, boundary_subsample, ar_steps = subsample_config + + datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) + + # Configure dataset with subsampling + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=datastore_boundary, + split="train", + ar_steps=ar_steps, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, + interior_subsample_step=interior_subsample, + boundary_subsample_step=boundary_subsample, + ) + + # Get first sample + init_states, target_states, forcing, boundary, target_times = dataset[0] + + # Check shapes + assert init_states.shape[0] == 2 # Always 2 initial states + assert target_states.shape[0] == ar_steps + + # Check time differences + times = target_times.numpy() + for i in range(1, len(times)): + time_delta = np.timedelta64(times[i] - times[i - 1], "ns") + expected_hours = interior_subsample * datastore.step_length + np.testing.assert_equal( + time_delta.astype("timedelta64[h]").astype(int), expected_hours + ) + + # Verify boundary data timesteps if present + if boundary is not None: + assert boundary.shape[0] == ar_steps + # Each boundary window should have: + # (num_past + num_future + 1) timesteps * features * 2 (for time deltas) + expected_boundary_features = ( + datastore_boundary.get_num_data_vars("forcing") + 1 + ) * ( + 1 + 1 + 1 + ) # past + future + current + assert boundary.shape[2] == expected_boundary_features + + # Verify we can access the last item + dataset[len(dataset) - 1] + + +@pytest.mark.parametrize( + "num_past_steps,num_future_steps,interior_step,boundary_step", + [ + (1, 1, 1, 1), # Base case, no subsampling + (2, 1, 1, 1), # More past steps, no subsampling + (1, 2, 1, 1), # More future steps, no subsampling + (2, 2, 1, 1), # Equal past/future, no subsampling + (1, 1, 1, 2), # Basic case with boundary subsampling + (2, 2, 1, 2), # Equal past/future with boundary subsampling + (1, 1, 2, 1), # Basic case with interior subsampling + (2, 2, 2, 1), # Equal past/future with interior subsampling + (1, 1, 2, 2), # Both subsamplings + ], +) +def test_time_deltas_in_boundary_data( + num_past_steps, num_future_steps, interior_step, boundary_step +): + """Test that time deltas are correctly calculated for boundary data. + + This test verifies: + 1. Time deltas are included in boundary data + 2. Time deltas are in units of state timesteps + 3. Time deltas are correctly calculated relative to current timestep + 4. Time steps scale correctly with subsampling + """ + # Create dummy data with known timesteps (3 hour intervals for interior) + time_values_interior = np.datetime64("2020-01-01") + np.arange( + 20 + ) * np.timedelta64(3, "h") + # 6 hour intervals for boundary + time_values_boundary = np.datetime64("2020-01-01") + np.arange( + 10 + ) * np.timedelta64(6, "h") + + time_step_ratio = ( + 6 / 3 + ) # Boundary step is 6 hours, interior step is 3 hours + + state_data = np.arange(20) + forcing_data = np.arange(20, 40) + boundary_data = np.arange(10) # Fewer points due to larger time step + + interior_datastore = SinglePointDummyDatastore( + state_data=state_data, + forcing_data=forcing_data, + time_values=time_values_interior, + is_forecast=False, + ) + + boundary_datastore = BoundaryDummyDatastore( + state_data=boundary_data, + forcing_data=boundary_data + 10, + time_values=time_values_boundary, + is_forecast=False, + ) + + dataset = WeatherDataset( + datastore=interior_datastore, + datastore_boundary=boundary_datastore, + split="train", + ar_steps=2, + num_past_boundary_steps=num_past_steps, + num_future_boundary_steps=num_future_steps, + interior_subsample_step=interior_step, + boundary_subsample_step=boundary_step, + standardize=False, + ) + + # Get first sample + _, _, _, boundary, target_times = dataset[0] + + # Extract time deltas from boundary data + # Time deltas are the last features in the boundary tensor + window_size = num_past_steps + num_future_steps + 1 + time_deltas = boundary[0, 0, -window_size:].numpy() + + # Expected time deltas in state timesteps, adjusted for boundary subsampling + # For each window position, calculate expected offset from current time + expected_deltas = ( + np.arange(-num_past_steps, num_future_steps + 1) + * boundary_step + * time_step_ratio + ) + + # Verify time deltas match expected values + np.testing.assert_array_equal(time_deltas, expected_deltas) + + # Calculate expected hours offset from current time + # Each state timestep is 3 hours, scale by boundary step + expected_hours = expected_deltas * boundary_datastore.step_length + time_delta_hours = time_deltas * boundary_datastore.step_length + + # Verify time delta hours match expected values + np.testing.assert_array_equal(time_delta_hours, expected_hours) + + # Verify relative hour differences between timesteps + expected_hour_diff = ( + boundary_step * boundary_datastore.step_length * time_step_ratio + ) + hour_diffs = np.diff(time_delta_hours) + np.testing.assert_array_equal( + hour_diffs, [expected_hour_diff] * (len(time_delta_hours) - 1) + ) + + # Extract boundary times and verify they match expected hours + for i in range(len(target_times)): + window_start_idx = i * (window_size * 2) + window_end_idx = window_start_idx + window_size + boundary_times = boundary[i, 0, window_start_idx:window_end_idx].numpy() + boundary_time_diffs = ( + np.diff(boundary_times) * boundary_datastore.step_length + ) + expected_diff = boundary_step * boundary_datastore.step_length + np.testing.assert_array_equal( + boundary_time_diffs, [expected_diff] * (len(boundary_times) - 1) + ) diff --git a/tests/test_training.py b/tests/test_training.py new file mode 100644 index 00000000..4773bbf3 --- /dev/null +++ b/tests/test_training.py @@ -0,0 +1,131 @@ +# Standard library +from pathlib import Path + +# Third-party +import pytest +import pytorch_lightning as pl +import torch +import wandb + +# First-party +from neural_lam import config as nlconfig +from neural_lam.build_rectangular_graph import build_graph_from_archetype +from neural_lam.datastore import DATASTORES +from neural_lam.datastore.base import BaseRegularGridDatastore +from neural_lam.models.graph_lam import GraphLAM +from neural_lam.weather_dataset import WeatherDataModule +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + get_test_mesh_dist, + init_datastore_boundary_example, + init_datastore_example, +) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_training(datastore_name, datastore_boundary_name): + datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_name} as it is not a regular " + "grid datastore." + ) + if not isinstance(datastore_boundary, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_boundary_name} as it is not a " + "regular grid datastore." + ) + + if torch.cuda.is_available(): + device_name = "cuda" + torch.set_float32_matmul_precision( + "high" + ) # Allows using Tensor Cores on A100s + else: + device_name = "cpu" + + trainer = pl.Trainer( + max_epochs=1, + deterministic=True, + accelerator=device_name, + # XXX: `devices` has to be set to 2 otherwise + # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails + # because it expects to aggregate over multiple devices + devices=2, + log_every_n_steps=1, + ) + + flat_graph_name = "1level" + + graph_dir_path = Path(datastore.root_path) / "graphs" / flat_graph_name + + if not graph_dir_path.exists(): + build_graph_from_archetype( + datastore=datastore, + datastore_boundary=datastore_boundary, + graph_name=flat_graph_name, + archetype="keisler", + mesh_node_distance=get_test_mesh_dist( + datastore, datastore_boundary + ), + ) + + data_module = WeatherDataModule( + datastore=datastore, + datastore_boundary=datastore_boundary, + ar_steps_train=3, + ar_steps_eval=5, + standardize=True, + batch_size=2, + num_workers=1, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, + ) + + class ModelArgs: + output_std = False + loss = "mse" + restore_opt = False + n_example_pred = 1 + # XXX: this should be superfluous when we have already defined the + # model object no? + graph_name = flat_graph_name + hidden_dim = 4 + hidden_layers = 1 + processor_layers = 2 + mesh_aggr = "sum" + lr = 1.0e-3 + val_steps_to_log = [1, 3] + metrics_watch = [] + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + shared_grid_embedder = False + + model_args = ModelArgs() + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + model = GraphLAM( + args=model_args, + datastore=datastore, + datastore_boundary=datastore_boundary, + config=config, + ) # noqa + + wandb.init() + trainer.fit(model=model, datamodule=data_module) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..ab978887 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,30 @@ +# Standard library +import copy + +# Third-party +import torch + +# First-party +from neural_lam.utils import BufferList + + +def test_bufferlist_idiv(): + """Test in-place division of bufferlist""" + + tensors_to_buffer = [i * torch.ones(5) for i in range(3)] + tensors_for_checking = copy.deepcopy(tensors_to_buffer) + blist = BufferList(tensors_to_buffer) + + divisor = 5.0 + div_tensors = [ten / divisor for ten in tensors_for_checking] + div_blist = copy.deepcopy(blist) + div_blist /= divisor + for bl_ten, check_ten in zip(div_tensors, div_blist): + torch.testing.assert_allclose(bl_ten, check_ten) + + multiplier = 2.0 + mult_tensors = [ten * multiplier for ten in tensors_for_checking] + mult_blist = copy.deepcopy(blist) + mult_blist *= multiplier + for bl_ten, check_ten in zip(mult_tensors, mult_blist): + torch.testing.assert_allclose(bl_ten, check_ten)