Design mllam-verification package #2

Draft: wants to merge 8 commits into base: main (this view shows changes from 3 commits)
68 changes: 68 additions & 0 deletions README.md
# mllam-verification
Verification of neural-lam, e.g. performance of the model relative to the truth, persistence, etc.

## General API patterns

- Every plot function should have `ax=None` as an input argument. If no axes object is provided, a new figure with appropriate settings/features (e.g. coastlines for maps) should be created. Otherwise, the plot should be added to the provided axes on top of the existing plot with a `zorder = 10*n`, where `n` is an integer. One can also pass `zorder` as an input argument to the plot function to place the plot at a specific level in the plot hierarchy. (A minimal sketch of this pattern is shown at the end of this section.)
- Every plot function should have an `include_persistence` input argument that defaults to `True` if persistence can be added for the given plot type, and `False` if not. If `include_persistence=True` but the plot type does not support plotting persistence, an error should be raised.
- Every plot function should take the metric to compute and plot as an input argument.
- The functions shall be callable from Jupyter notebooks/Python scripts as well as from the CLI.
**Comment by @observingClouds (imported from Confluence):**

> CLI sounds complicated, especially for defining variables, datasets etc.

**Reply (Collaborator Author):**

> Okay, so we should not have a CLI?

- The top-level plot functions should be named according to the plot they produce. They should not directly contain the logic to compute the metrics, but should instead call dedicated compute functions to do that.
- The package shall support specifying plot settings, output paths for saving plots, etc. in a config file.
**Comment (Collaborator Author):**

> Do we want this, or should people just adjust the axes/figure by themselves after calling the plot functions?

**Comment by @observingClouds (imported from Confluence):**

> I think the package should stick to the minimum and not provide any config parser: just simple functions with input arguments, and axes or values as return values, depending on the function's aim.
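To make the `ax`/`zorder` pattern from the first bullet concrete, here is a minimal sketch; the function name `plot_example_metric` and its defaults are illustrative, not a settled API:

```python
from typing import Optional

import matplotlib.pyplot as plt
import xarray as xr


def plot_example_metric(
    da_metric: xr.DataArray,
    ax: Optional[plt.Axes] = None,
    zorder: int = 10,
) -> plt.Axes:
    """Illustrative plot function following the proposed API pattern."""
    if ax is None:
        # No axes provided: create a new figure with appropriate settings
        # (for maps this step would also add a projection, coastlines, etc.)
        _, ax = plt.subplots(figsize=(8, 4))
    # Otherwise draw on top of the existing plot at a caller-controllable zorder
    da_metric.plot.line(ax=ax, zorder=zorder)
    return ax
```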


### Example of Python functions

In [mllam_verification.plot](https://github.com/mllam/mllam-verification/pull/2/files#diff-15c7e6b996e82bcc0f02abdd0aa702dcc8d0afae552cec7872cbc1005081c896) there is an example of a plot function that produces the single-metric-timeseries plot.

This plot function will call a `calculate_{metric}` function (located in the `metrics.py` module), which (for `metric=rmse`) could look like [this](https://github.com/mllam/mllam-verification/pull/2/files#diff-83955882ad971de54d93c75790537d89f01aaea7969c6e31109456d248aa1a20), with a persistence calculation function as shown later in that file.

These two functions will make use of an RMSE function located in the [statistics.py](https://github.com/mllam/mllam-verification/pull/2/files#diff-938dea8677e64061150529b5d1ca7753c260b27568d2d7be0a911edefe3fbd6a) module. The `statistics.py` functions will call functions from the `scores` Python package where possible and add relevant CF-compliant `cell_methods` attributes.
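Put together, a hypothetical usage sketch from a notebook or script; the dataset paths, the variable name `t2m`, and the top-level export of `plot_single_metric_timeseries` are assumptions:

```python
import xarray as xr

import mllam_verification as mlverif

# Placeholder inputs: any datasets with the dimensions described in the
# metric docstrings (start_time, elapsed_forecast_duration, ...) would do
ds_reference = xr.open_zarr("analysis.zarr")
ds_prediction = xr.open_zarr("forecast.zarr")

# Call chain: plot_single_metric_timeseries -> calculate_rmse -> statistics.rmse
axes = mlverif.plot_single_metric_timeseries(
    ds_reference, ds_prediction, variable="t2m", metric="rmse"
)
axes.figure.savefig("t2m_rmse_timeseries.png")
```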

## Python API
The mllam_verification package should be structured according to the directory structure below. As an example, the above plot function `plot_single_metric_timeseries` will be located in `mllam_verification/plot.py`.

```
.
├── mllam_verification
│   ├── operations
│   │   ├── __init__.py
│   │   ├── dataset_manipulation.py  # Functions for dataset manipulation, e.g. aligning shapes etc.
│   │   ├── loading.py               # Functions for loading data
│   │   ├── saving.py                # Functions for saving data and plots
│   │   ├── plot_utils.py            # Utility functions for plotting, e.g. color maps, figure instantiation and formatting etc.
│   │   ├── statistics.py            # Functions for computing statistics, e.g. mean, std, etc.
│   │   └── metrics.py               # Functions for computing metrics, e.g. mean squared error, etc.
│   ├── __init__.py
│   ├── __main__.py                  # Entry point of the package
│   ├── argument_parser.py           # CLI argument parser
│   ├── config.py                    # Config file parser
│   └── plot.py                      # Main module for producing plots
├── tests
│   ├── conftest.py
│   ├── unit
│   │   ├── conftest.py
│   │   └── ...
│   └── integration
│       ├── conftest.py
│       └── ...
├── pdm.lock
├── pyproject.toml
├── example.yaml                     # Example config file
└── README.md
```

## CLI API
The package shall be callable with the following arguments:
```bash
mllam_verification -r/--reference /path/to/ds_reference \
                   -p/--prediction /path/to/ds_prediction \
                   -m/--metric <metric_name> \
                   --var <variable-name(s)> \
                   --plot <name-of-plot(s)> \
                   --save-plots --save-dir /path/to/output
```
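For example, a hypothetical invocation (the paths, variable name and plot name are illustrative):

```bash
mllam_verification \
  --reference /data/analysis.zarr \
  --prediction /data/forecast.zarr \
  --metric rmse \
  --var t2m \
  --plot single_metric_timeseries \
  --save-plots \
  --save-dir ./verification_plots
```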

## Supported plots
The following is a first draft of the plots we want to make available in the mllam-verification package and what they support:
| Name | Plot function name | Example | Grouped | Elapsed | UTC | Multi | Multi model | Multi variable | Point | Regular |
|-----------------------------|-------------------------------------|------------------|---------|---------|-----|-------|-------------|---------------|-------|---------|
| Single metric timeseries | `plot_single_metric_timeseries` | ![single_metric_timeseries_example](./docs/_images/single_metric_timeseries_example.png) | ✅¹| ✅¹| ✅ | ✅ | ❌ | ❌ | ✅ | ✅ |
| Single metric hovmöller | `plot_single_metric_hovmoller` | ![single_metric_hovmoller_example](./docs/_images/single_metric_hovmoller_example.png) | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ |
| Single metric gridded map | `plot_single_metric_gridded_map` | ![single_metric_gridded_map](./docs/_images/single_metric_gridded_map_example.png) | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ |
| Single metric point map | `plot_single_metric_point_map` | ![single_metric_point_map](./docs/_images/single_metric_point_map_example.png) | ✅ | ❌ | ❌ | ✅² | ✅² | ❌ | ✅ | ✅ |

¹ without persistence\
² maybe not a good idea, e.g. if points overlap in the grid.
Binary file added docs/_images/single_metric_gridded_map_example.png
Binary file added docs/_images/single_metric_hovmoller_example.png
Binary file added docs/_images/single_metric_point_map_example.png
Binary file added docs/_images/single_metric_timeseries_example.png
157 changes: 157 additions & 0 deletions mllam_verification/operations/metrics.py
from typing import List

import xarray as xr

from .statistics import rmse


def calculate_rmse(
ds_reference: xr.Dataset,
ds_prediction: xr.Dataset,
variable: str,
reduce_dims: List[str],
    include_persistence: bool = False,
) -> xr.Dataset:
"""Calculate RMSE between prediction and reference datasets.

RMSE: Root Mean Square Error

If specified, calculate the error relative to persistence too.
The calculation is done only for the specified variable.
The input datasets are assumed to have the following specifications:

    Dimensions: [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]
    Data variables:
    - variable1 [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]:
    - variable2 [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]:
    - ...
    - variableN [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]:
    Coordinates:
    - start_time:
        the analysis time as a datetime object
    - elapsed_forecast_duration:
        the elapsed forecast duration as a timedelta object
    - reduce_dim1:
        one of the dimensions to reduce along when calculating the metric
    - reduce_dim2:
        one of the dimensions to reduce along when calculating the metric
- ...

The error is averaged along the start_time dimension of the datasets.
The error is returned as a dataset with the following specification:

Dimensions: [elapsed_forecast_duration]
Data variables:
- <variable>_rmse [elapsed_forecast_duration]:
the RMSE between the prediction and reference datasets
- <variable>_persistence [elapsed_forecast_duration], optional:
the persistence RMSE calculated based on the reference datasets
Coordinates:
- elapsed_forecast_duration:

    Parameters
    ----------
ds_reference: xr.Dataset
The reference dataset to calculate global error against.
ds_prediction: xr.Dataset
The prediction dataset to calculate global error of.
variable: str
The variable to calculate the metric of.
reduce_dims: List[str]
The dimensions to reduce along when calculating the metric.
include_persistence: bool
Whether to calculate the error relative to persistence
"""
# Select the variable from the datasets
ds_reference = ds_reference[[variable]]
ds_prediction = ds_prediction[[variable]]

# Calculate the error and rename the variable
ds_metric = rmse(ds_prediction, ds_reference, reduce_dims=reduce_dims)
ds_metric = ds_metric.rename({variable: f"{variable}_rmse"})

# Calculate the persistence error and merge with the metric dataset
if include_persistence:
ds_persistence_metric = calculate_persistence_rmse(
ds_reference, variable, reduce_dims=reduce_dims
)
ds_metric = xr.merge([ds_metric, ds_persistence_metric])

# Take mean over all start times
ds_metric = ds_metric.mean("start_time")
    # Update cell_methods attributes (append without assuming one is already set)
    for _, da_var in ds_metric.items():
        da_var.attrs["cell_methods"] = " ".join(
            [da_var.attrs.get("cell_methods", ""), "start_time: mean"]
        ).strip()

return ds_metric


def calculate_persistence_rmse(
ds_reference: xr.Dataset, variable: str, reduce_dims: List[str]
) -> xr.Dataset:
"""Calculate the RMSE between the reference dataset and its persistence.

RMSE: Root Mean Square Error

The calculation is done only for the specified variable.
The input dataset is assumed to have the following specifications:

    Dimensions: [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]
    Data variables:
    - variable1 [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]:
    - variable2 [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]:
    - ...
    - variableN [start_time, elapsed_forecast_duration, reduce_dim1, reduce_dim2, ...]:
Coordinates:
- start_time:
the analysis time as a datetime object
- elapsed_forecast_duration:
the elapsed forecast duration as a timedelta object
- reduce_dim1:
one of the dimensions to reduce along when calculating the persistence
- reduce_dim2:
one of the dimensions to reduce along when calculating the persistence
- ...

The error is returned as a dataset with the following specification:

Dimensions: [elapsed_forecast_duration]
Data variables:
- <variable>_persistence [elapsed_forecast_duration]:
the persistence RMSE calculated based on the reference dataset
Coordinates:
- elapsed_forecast_duration:

    Parameters
    ----------
ds_reference: xr.Dataset
The reference dataset to calculate persistence error against.
variable: str
The variable to calculate the persistence of.
reduce_dims: List[str]
The dimensions to reduce along when calculating the persistence.
"""
# Select the variable from the dataset
ds_reference = ds_reference[[variable]]

ds_persistence_reference = ds_reference.isel(
elapsed_forecast_duration=slice(1, None)
)
ds_persistence_prediction = ds_reference.isel(
elapsed_forecast_duration=slice(0, -1)
)
ds_persistence_prediction = ds_persistence_prediction.assign_coords(
elapsed_forecast_duration=ds_persistence_reference["elapsed_forecast_duration"]
)
ds_persistence_metric = rmse(
ds_prediction=ds_persistence_prediction,
ds_reference=ds_persistence_reference,
reduce_dims=reduce_dims,
)
ds_persistence_metric = ds_persistence_metric.rename(
{variable: f"{variable}_persistence"}
)

return ds_persistence_metric
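A toy-data sketch exercising `calculate_rmse` above; the dimension sizes and the `grid_index` dimension name are assumptions, and the package is assumed importable with the layout from the README:

```python
import numpy as np
import pandas as pd
import xarray as xr

from mllam_verification.operations.metrics import calculate_rmse

# Synthetic forecast set: 3 analysis times, 4 lead times, 10 grid points
start_times = pd.date_range("2024-01-01", periods=3, freq="12h")
lead_times = pd.timedelta_range("1h", periods=4, freq="1h")
rng = np.random.default_rng(0)

ds_reference = xr.Dataset(
    {"t2m": (("start_time", "elapsed_forecast_duration", "grid_index"), rng.random((3, 4, 10)))},
    coords={
        "start_time": start_times,
        "elapsed_forecast_duration": lead_times,
        "grid_index": np.arange(10),
    },
)
ds_prediction = ds_reference + 0.1  # a "forecast" with a constant bias

ds_metric = calculate_rmse(
    ds_reference,
    ds_prediction,
    variable="t2m",
    reduce_dims=["grid_index"],
    include_persistence=True,
)
# One value per lead time for both the model RMSE and the persistence RMSE
print(ds_metric["t2m_rmse"].values)
print(ds_metric["t2m_persistence"].values)
```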
23 changes: 23 additions & 0 deletions mllam_verification/operations/statistics.py
from typing import List

import scores.continuous as scc  # rmse and friends live in the scores.continuous module
import xarray as xr


def rmse(
    ds_prediction: xr.Dataset, ds_reference: xr.Dataset, reduce_dims: List[str]
) -> xr.Dataset:
    """Compute the root mean squared error along reduce_dims for all variables.

    Args:
        ds_prediction (xr.Dataset): Prediction dataset
        ds_reference (xr.Dataset): Reference dataset
        reduce_dims (List[str]): Dimensions to reduce along, e.g. ["grid_index"]
    Returns:
        xr.Dataset: Dataset with the computed RMSE for each variable
    """
ds_rmse = scc.rmse(ds_prediction, ds_reference, reduce_dims=reduce_dims)

    # Update cell_methods attributes (CF convention: "dim1: dim2: method")
    for _, da_var in ds_rmse.items():
        da_var.attrs["cell_methods"] = ": ".join(reduce_dims) + ": root_mean_square"

return ds_rmse
46 changes: 46 additions & 0 deletions mllam_verification/plot.py
from typing import Optional

import matplotlib.pyplot as plt
import xarray as xr

import mllam_verification as mlverif


**Comment (Collaborator Author)** on the function signature:

> I'm a bit in doubt about how we want to include the possibility to use either grouped, elapsed, UTC or multiple times in this and other plot functions. Should we have an input parameter like `time_type: str = Literal["grouped"|"elapsed"|"UTC"|"multi"]` and then an extra optional `groupby: Optional[str] = None` argument? Or should we have four different input arguments: `groupby: Optional[str] = None`, `elapsed: Optional[bool] = False`, `utc: Optional[bool] = False` and `multi: Optional[bool] = False`? Or do you have other ideas?

def plot_single_metric_timeseries(
    ds_reference: xr.Dataset,
    ds_prediction: xr.Dataset,
    variable: str,
    metric: str,
    axes: Optional[plt.Axes] = None,
    include_persistence: bool = True,
    xarray_plot_kwargs: Optional[dict] = None,
):
"""Plot a single-metric-timeseries diagram for a given variable and metric.

The metric is calculated from ds_reference and ds_prediction.

Parameters
----------
ds_reference : xr.Dataset
Reference dataset.
ds_prediction : xr.Dataset
Prediction dataset.
variable : str
Variable to calculate metric of.
metric : str
Metric to calculate.
    axes : plt.Axes, optional
        Axes to plot on, by default None
    include_persistence : bool, optional
        Whether to also compute and plot the persistence error, by default True
    xarray_plot_kwargs : dict, optional
        Additional arguments to pass to xarray's plot function, by default None
"""

    if xarray_plot_kwargs is None:
        xarray_plot_kwargs = {}

    # Dispatch to the metric-specific function, e.g. calculate_rmse for metric="rmse"
    calculate_metric = getattr(mlverif.operations.metrics, f"calculate_{metric}")
    ds_metric = calculate_metric(
        ds_reference,
        ds_prediction,
        variable,
        reduce_dims=["grid_index"],  # assumed spatial dimension; the final API may expose this
        include_persistence=include_persistence,
    )

if axes is None:
axes = mlverif.operations.plot_utils.get_axes(plot_type="timeseries")
**Comment (Member):**

> I would just init a new figure directly here rather than putting that somewhere else, that would be more explicit.

Suggested change:
```diff
-    axes = mlverif.operations.plot_utils.get_axes(plot_type="timeseries")
+    fig, axes = plt.subplots()
```

**Reply (Collaborator Author):**

> I just thought that this step would also include e.g. setting up the projection, coastlines etc. if needed, so it might be more than just a one-liner.


ds_metric[metric].plot.line(ax=axes, **xarray_plot_kwargs)

if include_persistence:
ds_metric["persistence"].plot.line(ax=axes, **xarray_plot_kwargs)
**Comment (Member):**

> Maybe the metric calculation could add a new coordinate to the dataset called, say, `data_source`, which could include names like `[persistence, DINI_forecast, ...]` (`persistence` would only be there if `include_persistence==True`; the other names could be set from dict keys if we pass in multiple prediction datasets). Then the plotting could be replaced with:

Suggested change:
```diff
-    ds_metric[metric].plot.line(ax=axes, **xarray_plot_kwargs)
-    if include_persistence:
-        ds_metric["persistence"].plot.line(ax=axes, **xarray_plot_kwargs)
+    ds_metric[metric].plot.line(ax=axes, hue="data_source", **xarray_plot_kwargs)
```

**Comment (Member):**

> I don't think we should need to select the metric variable here; either we compute the same metric for all variables and select a specific variable here, or we only calculate the metric for a specific variable, in which case the return from the metric function could just be a data-array.

**Comment (Member):**

> Regarding my comment above: maybe, rather than using a dict, we should require that the user does the concatenation along a `data_source` dimension if the `ds_prediction` argument is to contain multiple sources.

**Reply (Collaborator Author):**

> > I don't think we should need to select the metric variable here, either we compute the same metric for all variables and select a specific variable here, or only calculate the metric for a specific variable, in which case the return from the metric function could just be a data-array
>
> Agree, the `[metric]` part is not needed. I think we should stick with only calculating one metric for one variable and let the metric function return a data-array.

**Reply (Collaborator Author):**

> > Regarding my comment above, maybe we should rather than using a dict require that the user does concatenation along a data_source dimension if the ds_prediction argument is to contain multiple sources
>
> We could do that, yes. Not sure which one is the better design.



return axes
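To illustrate the `data_source` idea from the thread above, a hypothetical sketch (nothing here is implemented in this PR; all names are illustrative):

```python
import numpy as np
import pandas as pd
import xarray as xr

lead_times = pd.timedelta_range("1h", periods=4, freq="1h")
ds_a = xr.Dataset(
    {"t2m_rmse": ("elapsed_forecast_duration", np.random.default_rng(0).random(4))},
    coords={"elapsed_forecast_duration": lead_times},
)
ds_b = ds_a * 1.2  # a second, slightly worse "model"

# Concatenate metric datasets from several sources along a new
# data_source dimension ...
ds_all = xr.concat(
    [ds_a, ds_b], dim=pd.Index(["model_a", "model_b"], name="data_source")
)
# ... so a single plot call draws one line per source
ds_all["t2m_rmse"].plot.line(x="elapsed_forecast_duration", hue="data_source")
```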