From d93c31b77572d605f4dabe15a9bff1b275a17852 Mon Sep 17 00:00:00 2001 From: Hauke Schulz <43613877+observingClouds@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:01:59 +0100 Subject: [PATCH 1/2] mypy typing for typed functions --- .pre-commit-config.yaml | 6 +++ neural_lam/config.py | 18 ++++---- neural_lam/create_graph.py | 9 ++-- neural_lam/datastore/__init__.py | 3 +- neural_lam/datastore/base.py | 22 +++++++--- neural_lam/datastore/mdp.py | 9 ++-- neural_lam/datastore/npyfilesmeps/store.py | 20 ++++++--- neural_lam/datastore/plot_example.py | 3 +- neural_lam/models/ar_model.py | 20 +++++---- neural_lam/vis.py | 6 +-- neural_lam/weather_dataset.py | 48 +++++++++++----------- tests/dummy_datastore.py | 5 ++- tests/test_time_slicing.py | 7 +++- 13 files changed, 104 insertions(+), 72 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dfbf8b60..594ddde7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,3 +36,9 @@ repos: - id: flake8 description: Check Python code for correctness, consistency and adherence to best practices additional_dependencies: [Flake8-pyproject] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.14.1 + hooks: + - id: mypy + additional_dependencies: [types-PyYAML] + description: Check for type errors diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e09697..fd4b4def 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -15,15 +15,6 @@ ) -class DatastoreKindStr(str): - VALID_KINDS = DATASTORES.keys() - - def __new__(cls, value): - if value not in cls.VALID_KINDS: - raise ValueError(f"Invalid datastore kind: {value}") - return super().__new__(cls, value) - - @dataclasses.dataclass class DatastoreSelection: """ @@ -31,7 +22,7 @@ class DatastoreSelection: Attributes ---------- - kind : DatastoreKindStr + kind : str The kind of datastore to use, currently `mdp` or `npyfilesmeps` are implemented. config_path : str @@ -39,7 +30,12 @@ class DatastoreSelection: assumed to be relative to the configuration file for neural-lam. 
""" - kind: DatastoreKindStr + kind: str + + def __post_init__(self): + if self.kind not in DATASTORES: + raise ValueError(f"Datastore kind {self.kind} is not implemented") + config_path: str diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index ef979be3..5ea5758b 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -1,6 +1,7 @@ # Standard library import os from argparse import ArgumentParser +from typing import Optional # Third-party import matplotlib @@ -157,9 +158,9 @@ def prepend_node_index(graph, new_index): def create_graph( graph_dir_path: str, xy: np.ndarray, - n_max_levels: int, - hierarchical: bool, - create_plot: bool, + n_max_levels: Optional[int] = None, + hierarchical: Optional[bool] = False, + create_plot: Optional[bool] = False, ): """ Create graph components from `xy` grid coordinates and store in @@ -538,7 +539,7 @@ def create_graph( def create_graph_from_datastore( datastore: BaseRegularGridDatastore, output_root_path: str, - n_max_levels: int = None, + n_max_levels: Optional[int] = None, hierarchical: bool = False, create_plot: bool = False, ): diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py index 40e683ac..dead7713 100644 --- a/neural_lam/datastore/__init__.py +++ b/neural_lam/datastore/__init__.py @@ -9,7 +9,8 @@ ] DATASTORES = { - datastore.SHORT_NAME: datastore for datastore in DATASTORE_CLASSES + datastore.SHORT_NAME: datastore # type: ignore + for datastore in DATASTORE_CLASSES } diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index f0291657..70f4c053 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -5,7 +5,7 @@ import functools from functools import cached_property from pathlib import Path -from typing import List, Union +from typing import List, Optional, Union # Third-party import cartopy.crs as ccrs @@ -215,7 +215,10 @@ def _standardize_datarray( @abc.abstractmethod def get_dataarray( - self, category: str, split: str, standardize: bool = False + self, + category: str, + split: Optional[str], + standardize: bool = False, ) -> Union[xr.DataArray, None]: """ Return the processed data (as a single `xr.DataArray`) for the given @@ -275,7 +278,7 @@ def boundary_mask(self) -> xr.DataArray: pass @abc.abstractmethod - def get_xy(self, category: str) -> np.ndarray: + def get_xy(self, category: str, stacked: bool) -> np.ndarray: """ Return the x, y coordinates of the dataset as a numpy arrays for a given category of data. @@ -284,6 +287,12 @@ def get_xy(self, category: str) -> np.ndarray: ---------- category : str The category of the dataset (state/forcing/static). + stacked : bool + Whether to stack the x, y coordinates. The parameter `stacked` has + been introduced in this class. Parent class `BaseDatastore` has the + same method signature but without the `stacked` parameter. Defaults + to `True` to match the behaviour of `BaseDatastore.get_xy()` which + always returns the coordinates stacked. Returns ------- @@ -364,7 +373,9 @@ def state_feature_weights_values(self) -> List[float]: pass @functools.lru_cache - def expected_dim_order(self, category: str = None) -> tuple[str]: + def expected_dim_order( + self, category: Optional[str] = None + ) -> tuple[str, ...]: """ Return the expected dimension order for the dataarray or dataset returned by `get_dataarray` for the given category of data. 
The @@ -471,7 +482,7 @@ def grid_shape_state(self) -> CartesianGridShape: pass @abc.abstractmethod - def get_xy(self, category: str, stacked: bool = True) -> np.ndarray: + def get_xy(self, category: str, stacked: bool) -> np.ndarray: """Return the x, y coordinates of the dataset. Parameters @@ -574,7 +585,6 @@ def stack_grid_coords( return da_or_ds_stacked.transpose(*dim_order) @property - @functools.lru_cache def num_grid_points(self) -> int: """Return the number of grid points in the dataset. diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 01d2b12b..e89efae5 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -3,7 +3,7 @@ import warnings from functools import cached_property from pathlib import Path -from typing import List +from typing import List, Optional, Union # Third-party import cartopy.crs as ccrs @@ -220,8 +220,11 @@ def get_num_data_vars(self, category: str) -> int: return len(self.get_vars_names(category)) def get_dataarray( - self, category: str, split: str, standardize: bool = False - ) -> xr.DataArray: + self, + category: str, + split: Optional[str], + standardize: bool = False, + ) -> Union[xr.DataArray, None]: """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 8f926f7e..481b2d0f 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -9,7 +9,7 @@ import warnings from functools import cached_property from pathlib import Path -from typing import List +from typing import List, Optional # Third-party import cartopy.crs as ccrs @@ -200,7 +200,7 @@ def config(self) -> NpyDatastoreConfig: return self._config def get_dataarray( - self, category: str, split: str, standardize: bool = False + self, category: str, split: Optional[str], standardize: bool = False ) -> DataArray: """ Get the data array for the given category and split of data. 
If the @@ -313,7 +313,10 @@ def get_dataarray( return da def _get_single_timeseries_dataarray( - self, features: List[str], split: str, member: int = None + self, + features: List[str], + split: Optional[str] = None, + member: Optional[int] = None, ) -> DataArray: """ Get the data array spanning the complete time series for a given set of @@ -376,7 +379,10 @@ def _get_single_timeseries_dataarray( add_feature_dim = False features_vary_with_analysis_time = True feature_dim_mask = None - if features == self.get_vars_names(category="state"): + if ( + features == self.get_vars_names(category="state") + and split is not None + ): filename_format = STATE_FILENAME_FORMAT file_dims = ["elapsed_forecast_duration", "y", "x", "feature"] # only select one member for now @@ -388,12 +394,14 @@ def _get_single_timeseries_dataarray( len(features) + n_to_drop, dtype=bool ) feature_dim_mask[self._remove_state_features_with_index] = False - elif features == ["toa_downwelling_shortwave_flux"]: + elif ( + features == ["toa_downwelling_shortwave_flux"] and split is not None + ): filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT file_dims = ["elapsed_forecast_duration", "y", "x", "feature"] add_feature_dim = True fp_samples = self.root_path / "samples" / split - elif features == ["open_water_fraction"]: + elif features == ["open_water_fraction"] and split is not None: filename_format = OPEN_WATER_FILENAME_FORMAT file_dims = ["y", "x", "feature"] add_feature_dim = True diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py index 2d477271..4f61ac7e 100644 --- a/neural_lam/datastore/plot_example.py +++ b/neural_lam/datastore/plot_example.py @@ -163,6 +163,7 @@ def _parse_dict(arg_str): selection = dict(args.selection) index_selection = dict(args.index_selection) + datastore_kind = args.datastore_kind # check that column dimension is not in the selection if args.col_dim.format(category=args.category) in selection: @@ -173,7 +174,7 @@ def _parse_dict(arg_str): ) datastore = init_datastore( - datastore_kind=args.datastore_kind, + datastore_kind=datastore_kind, config_path=args.datastore_config_path, ) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index d19805f1..9649d8ec 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,7 +1,7 @@ # Standard library import os import warnings -from typing import List, Union +from typing import Any, Dict, List # Third-party import matplotlib.pyplot as plt @@ -43,6 +43,8 @@ def __init__( da_static_features = datastore.get_dataarray( category="static", split=None, standardize=True ) + if da_static_features is None: + raise ValueError("Static features are required for ARModel") da_state_stats = datastore.get_standardization_dataarray( category="state" ) @@ -127,10 +129,10 @@ def __init__( "interior_mask", 1.0 - self.boundary_mask, persistent=False ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics = { + self.val_metrics: Dict[str, List] = { "mse": [], } - self.test_metrics = { + self.test_metrics: Dict[str, List] = { "mse": [], "mae": [], } @@ -145,12 +147,12 @@ def __init__( self.plotted_examples = 0 # For storing spatial loss maps during evaluation - self.spatial_loss_maps = [] + self.spatial_loss_maps: List[Any] = [] def _create_dataarray_from_tensor( self, tensor: torch.Tensor, - time: Union[int, List[int]], + time: torch.Tensor, split: str, category: str, ) -> xr.DataArray: @@ -166,9 +168,9 @@ def _create_dataarray_from_tensor( The tensor to convert to a `xr.DataArray` with 
dimensions [time, grid_index, feature]. The tensor will be copied to the CPU if it is not already there. - time : Union[int,List[int]] - The time index or indices for the data, given as integers or a list - of integers representing epoch time in nanoseconds. The ints will be + time : torch.Tensor + The time index or indices for the data, given as tensor representing + epoch time in nanoseconds. The tensor will be copied to the CPU memory if they are not already there. split : str The split of the data, either 'train', 'val', or 'test' @@ -181,7 +183,7 @@ def _create_dataarray_from_tensor( weather_dataset = WeatherDataset(datastore=self._datastore, split=split) time = np.array(time.cpu(), dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor.cpu().numpy(), time=time, category=category + tensor=tensor, time=time, category=category ) return da diff --git a/neural_lam/vis.py b/neural_lam/vis.py index d6b57f88..48b8563a 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -67,8 +67,8 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( datastore: BaseRegularGridDatastore, - da_prediction: xr.DataArray = None, - da_target: xr.DataArray = None, + da_prediction: xr.DataArray, + da_target: xr.DataArray, title=None, vrange=None, ): @@ -82,7 +82,7 @@ def plot_prediction( if vrange is None: vmin = min(da_prediction.min(), da_target.min()) vmax = max(da_prediction.max(), da_target.max()) - else: + elif vrange is not None: vmin, vmax = vrange extent = datastore.get_xy_extent("state") diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b5f85580..b3aac82c 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -43,11 +43,11 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, datastore: BaseDatastore, - split="train", - ar_steps=3, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - standardize=True, + split: str = "train", + ar_steps: int = 3, + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + standardize: bool = True, ): super().__init__() @@ -66,7 +66,7 @@ def __init__( # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples - if self.__len__() <= 0: + if self.__len__() <= 0 and self.da_state is not None: raise ValueError( "The provided datastore only provides " f"{len(self.da_state.time)} total time steps, which is too few " @@ -86,13 +86,14 @@ def __init__( expected_dim_order = self.datastore.expected_dim_order( category=part ) - if da.dims != expected_dim_order: - raise ValueError( - f"The dimension order of the `{part}` data ({da.dims}) " - f"does not match the expected dimension order " - f"({expected_dim_order}). Maybe you forgot to transpose " - "the data in `BaseDatastore.get_dataarray`?" - ) + if da is not None: + if da.dims != expected_dim_order: + raise ValueError( + f"The dimension order of the `{part}` data ({da.dims}) " + f"does not match the expected dimension order " + f"({expected_dim_order}). Maybe you forgot to " + "transpose the data in `BaseDatastore.get_dataarray`?" + ) # Set up for standardization # TODO: This will become part of ar_model.py soon! 
@@ -553,7 +554,7 @@ def _is_listlike(obj): raise ValueError( "Expected a single time for a 2D tensor with assumed " "dimensions (grid_index, {category}_feature), but got " - f"{len(time)} times" + f"{len(time)} times" # type: ignore ) elif len(tensor.shape) == 3: add_time_as_dim = True @@ -606,13 +607,13 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, datastore: BaseDatastore, - ar_steps_train=3, - ar_steps_eval=25, - standardize=True, - num_past_forcing_steps=1, - num_future_forcing_steps=1, - batch_size=4, - num_workers=16, + ar_steps_train: int = 3, + ar_steps_eval: int = 25, + standardize: bool = True, + num_past_forcing_steps: int = 1, + num_future_forcing_steps: int = 1, + batch_size: int = 4, + num_workers: int = 16, ): super().__init__() self._datastore = datastore @@ -622,16 +623,15 @@ def __init__( self.ar_steps_eval = ar_steps_eval self.standardize = standardize self.batch_size = batch_size - self.num_workers = num_workers + self.num_workers: int = num_workers self.train_dataset = None self.val_dataset = None self.test_dataset = None + self.multiprocessing_context: Union[str, None] = None if num_workers > 0: # default to spawn for now, as the default on linux "fork" hangs # when using dask (which the npyfilesmeps datastore uses) self.multiprocessing_context = "spawn" - else: - self.multiprocessing_context = None def setup(self, stage=None): if stage == "fit" or stage is None: diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 0c76bca8..1320d742 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -3,7 +3,7 @@ import tempfile from functools import cached_property from pathlib import Path -from typing import List, Union +from typing import List, Optional, Tuple, Union # Third-party import isodate @@ -125,6 +125,7 @@ def __init__( # Define dimensions and create random data dims = ["grid_index", f"{category}_feature"] + shape: Tuple[int, ...] 
if category != "static": dims.append("time") shape = (n_grid_points, n, n_timesteps) @@ -301,7 +302,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: return ds_standardization def get_dataarray( - self, category: str, split: str, standardize: bool = False + self, category: str, split: Optional[str], standardize: bool = False ) -> Union[xr.DataArray, None]: """ Return the processed data (as a single `xr.DataArray`) for the given diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 29161505..4290f7fa 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -1,3 +1,6 @@ +# Standard library +from pathlib import Path + # Third-party import numpy as np import pytest @@ -10,10 +13,10 @@ class SinglePointDummyDatastore(BaseDatastore): step_length = 1 - config = None + config = {} coords_projection = None num_grid_points = 1 - root_path = None + root_path = Path("dummy") def __init__(self, time_values, state_data, forcing_data, is_forecast): self._time_values = np.array(time_values) From a18a5a9af07f2894043ad9055428bfeaf3e171d6 Mon Sep 17 00:00:00 2001 From: Hauke Schulz <43613877+observingClouds@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:26:57 +0100 Subject: [PATCH 2/2] add further stubs --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 594ddde7..772a6a14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,5 +40,5 @@ repos: rev: v1.14.1 hooks: - id: mypy - additional_dependencies: [types-PyYAML] + additional_dependencies: [types-PyYAML, types-Pillow, types-tqdm] description: Check for type errors
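Notes on the typing patterns used in this series follow.

The config.py hunk replaces the `DatastoreKindStr` str subclass (which validated in `__new__`) with a plain `str` field validated in `__post_init__`, so the annotation mypy sees is just `str`. A minimal standalone sketch of the pattern; the `DATASTORES` dict here is an illustrative stand-in for the real registry in `neural_lam.datastore`:

    import dataclasses

    # Illustrative stand-in for the real DATASTORES registry
    # (short name -> datastore class).
    DATASTORES = {"mdp": object, "npyfilesmeps": object}

    @dataclasses.dataclass
    class DatastoreSelection:
        kind: str
        config_path: str

        def __post_init__(self):
            # Validation happens at construction time; the field type
            # stays a plain `str` that mypy can check directly.
            if self.kind not in DATASTORES:
                raise ValueError(
                    f"Datastore kind {self.kind} is not implemented"
                )

    DatastoreSelection(kind="mdp", config_path="config.yaml")  # ok
    # kind="zarr" would raise ValueError at construction.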
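Several signatures change from the implicit `split: str = None` style to an explicit `Optional[str]`, and branches gain `... and split is not None` guards (the `elif vrange is not None:` in vis.py is the same trick). mypy only narrows an `Optional` type inside an explicit None check, which is what makes the guarded branches type-clean. A small sketch with a hypothetical `sample_path` helper, not the real datastore code:

    from typing import Optional

    def sample_path(category: str, split: Optional[str] = None) -> str:
        if split is not None:
            # Inside this branch mypy narrows `split` from
            # Optional[str] to str, so str-only usage type-checks.
            return f"samples/{split}/{category}"
        return f"static/{category}"

    print(sample_path("state", "train"))  # samples/train/state
    print(sample_path("static"))          # static/static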
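The base.py hunk drops `@functools.lru_cache` from under the `num_grid_points` property. Stacking the two keeps a reference to every instance in a cache attached to the wrapper function, and the wrapped descriptor confuses mypy; where per-instance caching is actually wanted, `functools.cached_property` is the usual replacement. A toy sketch, not the real `BaseRegularGridDatastore`:

    from functools import cached_property

    class Grid:
        def __init__(self, nx: int, ny: int) -> None:
            self.nx = nx
            self.ny = ny

        @cached_property
        def num_grid_points(self) -> int:
            # Computed on first access and stored on the instance, with
            # no global cache holding `self` alive.
            return self.nx * self.ny

    print(Grid(3, 4).num_grid_points)  # 12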
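On the pre-commit side: the mirrors-mypy hook runs mypy in its own isolated environment, so third-party stubs (`types-PyYAML` from the first patch, plus `types-Pillow` and `types-tqdm` from the second) must be listed under `additional_dependencies` rather than picked up from the project virtualenv. The hook can be exercised over the whole repository with `pre-commit run mypy --all-files`.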