Feature: mypy testing #113

Open · wants to merge 3 commits into base: main
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -36,3 +36,9 @@ repos:
       - id: flake8
         description: Check Python code for correctness, consistency and adherence to best practices
         additional_dependencies: [Flake8-pyproject]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.14.1
+    hooks:
+      - id: mypy
+        additional_dependencies: [types-PyYAML, types-Pillow, types-tqdm]
+        description: Check for type errors
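The new hook runs mypy on every commit; `pre-commit run mypy --all-files` checks the whole repository. As a hypothetical illustration (not code from this repo), this is the class of error the hook catches, and the fix mirrors several changes in this PR:

```python
from typing import Optional


def lookup(name: str, default: str = None) -> str:
    # mypy: Incompatible default for argument "default"
    # (default has type "None", argument has type "str")
    return default or name


def lookup_fixed(name: str, default: Optional[str] = None) -> str:
    # Accepted: the annotation now admits None explicitly
    return default or name
```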
18 changes: 7 additions & 11 deletions neural_lam/config.py
@@ -15,31 +15,27 @@
 )


-class DatastoreKindStr(str):
-    VALID_KINDS = DATASTORES.keys()
-
-    def __new__(cls, value):
-        if value not in cls.VALID_KINDS:
-            raise ValueError(f"Invalid datastore kind: {value}")
-        return super().__new__(cls, value)
-
-
 @dataclasses.dataclass
 class DatastoreSelection:
     """
     Configuration for selecting a datastore to use with neural-lam.

     Attributes
     ----------
-    kind : DatastoreKindStr
+    kind : str
         The kind of datastore to use, currently `mdp` or `npyfilesmeps` are
         implemented.
     config_path : str
         The path to the configuration file for the selected datastore, this is
         assumed to be relative to the configuration file for neural-lam.
     """

-    kind: DatastoreKindStr
+    kind: str
+
+    def __post_init__(self):
+        if self.kind not in DATASTORES:
+            raise ValueError(f"Datastore kind {self.kind} is not implemented")
+
     config_path: str
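The `DatastoreKindStr` subclass of `str` is hard for mypy to reason about, so the validation moves into the dataclass itself. A minimal sketch of the adopted pattern, with a stand-in registry in place of the real `DATASTORES` dict:

```python
import dataclasses

DATASTORES = {"mdp": object, "npyfilesmeps": object}  # stand-in registry


@dataclasses.dataclass
class DatastoreSelection:
    kind: str
    config_path: str

    def __post_init__(self):
        # Runs right after the generated __init__, so a bad kind fails early
        if self.kind not in DATASTORES:
            raise ValueError(f"Datastore kind {self.kind} is not implemented")


DatastoreSelection(kind="mdp", config_path="datastore.yaml")  # ok
# DatastoreSelection(kind="zarr", config_path="x.yaml")  -> ValueError
```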
9 changes: 5 additions & 4 deletions neural_lam/create_graph.py
@@ -1,6 +1,7 @@
 # Standard library
 import os
 from argparse import ArgumentParser
+from typing import Optional

 # Third-party
 import matplotlib
@@ -157,9 +158,9 @@ def prepend_node_index(graph, new_index):
 def create_graph(
     graph_dir_path: str,
     xy: np.ndarray,
-    n_max_levels: int,
-    hierarchical: bool,
-    create_plot: bool,
+    n_max_levels: Optional[int] = None,
+    hierarchical: Optional[bool] = False,
+    create_plot: Optional[bool] = False,
 ):
     """
     Create graph components from `xy` grid coordinates and store in
@@ -538,7 +539,7 @@ def create_graph(
 def create_graph_from_datastore(
     datastore: BaseRegularGridDatastore,
     output_root_path: str,
-    n_max_levels: int = None,
+    n_max_levels: Optional[int] = None,
     hierarchical: bool = False,
     create_plot: bool = False,
 ):
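Under PEP 484, `n_max_levels: int = None` is an error once mypy's implicit-Optional handling is disabled (the default since mypy 0.990), hence the `Optional[...]` annotations. Inside the function, an `Optional[int]` must then be narrowed before it is used as an `int`; a sketch with an assumed helper name, not code from `create_graph.py`:

```python
from typing import Optional


def resolve_n_levels(n_max_levels: Optional[int], n_possible: int) -> int:
    # Narrow the Optional before arithmetic; mypy rejects min() on an
    # un-narrowed Optional[int]
    if n_max_levels is None:
        return n_possible  # no cap given: use every mesh level
    return min(n_max_levels, n_possible)
```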
3 changes: 2 additions & 1 deletion neural_lam/datastore/__init__.py
@@ -9,7 +9,8 @@
 ]

 DATASTORES = {
-    datastore.SHORT_NAME: datastore for datastore in DATASTORE_CLASSES
+    datastore.SHORT_NAME: datastore  # type: ignore
+    for datastore in DATASTORE_CLASSES
 }
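The `# type: ignore` is needed because mypy cannot prove that every entry of `DATASTORE_CLASSES` carries a `SHORT_NAME` attribute. One typed alternative, shown here as a sketch with stand-in classes rather than the real neural_lam ones, is to declare the attribute as a `ClassVar` on the shared base:

```python
from typing import ClassVar, Dict, List, Type


class BaseDatastore:
    SHORT_NAME: ClassVar[str]  # every concrete datastore must set this


class MDPDatastore(BaseDatastore):
    SHORT_NAME = "mdp"


class NpyFilesDatastoreMEPS(BaseDatastore):
    SHORT_NAME = "npyfilesmeps"


DATASTORE_CLASSES: List[Type[BaseDatastore]] = [
    MDPDatastore,
    NpyFilesDatastoreMEPS,
]

# With SHORT_NAME visible on the base class, no ignore comment is needed
DATASTORES: Dict[str, Type[BaseDatastore]] = {
    datastore.SHORT_NAME: datastore for datastore in DATASTORE_CLASSES
}
```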
22 changes: 16 additions & 6 deletions neural_lam/datastore/base.py
@@ -5,7 +5,7 @@
 import functools
 from functools import cached_property
 from pathlib import Path
-from typing import List, Union
+from typing import List, Optional, Union

 # Third-party
 import cartopy.crs as ccrs
@@ -215,7 +215,10 @@ def _standardize_datarray(

     @abc.abstractmethod
     def get_dataarray(
-        self, category: str, split: str, standardize: bool = False
+        self,
+        category: str,
+        split: Optional[str],
+        standardize: bool = False,
     ) -> Union[xr.DataArray, None]:
         """
         Return the processed data (as a single `xr.DataArray`) for the given
@@ -275,7 +278,7 @@ def boundary_mask(self) -> xr.DataArray:
         pass

     @abc.abstractmethod
-    def get_xy(self, category: str) -> np.ndarray:
+    def get_xy(self, category: str, stacked: bool) -> np.ndarray:
         """
         Return the x, y coordinates of the dataset as a numpy arrays for a
         given category of data.
@@ -284,6 +287,12 @@ def get_xy(self, category: str) -> np.ndarray:
         ----------
         category : str
             The category of the dataset (state/forcing/static).
+        stacked : bool
+            Whether to stack the x, y coordinates. The parameter `stacked` has
+            been introduced in this class. Parent class `BaseDatastore` has the
+            same method signature but without the `stacked` parameter. Defaults
+            to `True` to match the behaviour of `BaseDatastore.get_xy()` which
+            always returns the coordinates stacked.

         Returns
         -------
@@ -364,7 +373,9 @@ def state_feature_weights_values(self) -> List[float]:
         pass

     @functools.lru_cache
-    def expected_dim_order(self, category: str = None) -> tuple[str]:
+    def expected_dim_order(
+        self, category: Optional[str] = None
+    ) -> tuple[str, ...]:
         """
         Return the expected dimension order for the dataarray or dataset
         returned by `get_dataarray` for the given category of data. The
@@ -471,7 +482,7 @@ def grid_shape_state(self) -> CartesianGridShape:
         pass

     @abc.abstractmethod
-    def get_xy(self, category: str, stacked: bool = True) -> np.ndarray:
+    def get_xy(self, category: str, stacked: bool) -> np.ndarray:
         """Return the x, y coordinates of the dataset.

         Parameters
@@ -574,7 +585,6 @@ def stack_grid_coords(
         return da_or_ds_stacked.transpose(*dim_order)

     @property
-    @functools.lru_cache
     def num_grid_points(self) -> int:
         """Return the number of grid points in the dataset.
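The dropped `@functools.lru_cache` under `@property` is both a typing and a runtime issue: mypy reports "Decorated property not supported", and `lru_cache` on a method keeps `self` alive in a cache shared across instances. Where per-instance caching is still wanted, `functools.cached_property` (already imported in this module) is the usual replacement; a sketch of the pitfall and the fix:

```python
import functools
from functools import cached_property


class GridA:
    @property
    @functools.lru_cache  # rejected by mypy; the cache also retains self
    def num_grid_points(self) -> int:
        return 100 * 200


class GridB:
    @cached_property  # computed once, then stored on the instance
    def num_grid_points(self) -> int:
        return 100 * 200
```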
9 changes: 6 additions & 3 deletions neural_lam/datastore/mdp.py
@@ -3,7 +3,7 @@
 import warnings
 from functools import cached_property
 from pathlib import Path
-from typing import List
+from typing import List, Optional, Union

 # Third-party
 import cartopy.crs as ccrs
@@ -220,8 +220,11 @@ def get_num_data_vars(self, category: str) -> int:
         return len(self.get_vars_names(category))

     def get_dataarray(
-        self, category: str, split: str, standardize: bool = False
-    ) -> xr.DataArray:
+        self,
+        category: str,
+        split: Optional[str],
+        standardize: bool = False,
+    ) -> Union[xr.DataArray, None]:
         """
         Return the processed data (as a single `xr.DataArray`) for the given
         category of data and test/train/val-split that covers all the data (in
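With the return type widened to `Union[xr.DataArray, None]`, every caller has to narrow the result before using it, which is exactly what the `ar_model.py` hunk below does. A small self-contained sketch of that call-site pattern:

```python
from typing import Optional

import xarray as xr


def require_dataarray(da: Optional[xr.DataArray], what: str) -> xr.DataArray:
    # Fail loudly when the datastore has no data for the request
    if da is None:
        raise ValueError(f"{what} is required but none was returned")
    return da  # narrowed to xr.DataArray from here on


da_static = require_dataarray(xr.DataArray([1.0, 2.0]), what="static features")
```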
20 changes: 14 additions & 6 deletions neural_lam/datastore/npyfilesmeps/store.py
@@ -9,7 +9,7 @@
 import warnings
 from functools import cached_property
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 # Third-party
 import cartopy.crs as ccrs
@@ -200,7 +200,7 @@ def config(self) -> NpyDatastoreConfig:
         return self._config

     def get_dataarray(
-        self, category: str, split: str, standardize: bool = False
+        self, category: str, split: Optional[str], standardize: bool = False
     ) -> DataArray:
         """
         Get the data array for the given category and split of data. If the
@@ -313,7 +313,10 @@ def get_dataarray(
         return da

     def _get_single_timeseries_dataarray(
-        self, features: List[str], split: str, member: int = None
+        self,
+        features: List[str],
+        split: Optional[str] = None,
+        member: Optional[int] = None,
     ) -> DataArray:
         """
         Get the data array spanning the complete time series for a given set of
@@ -376,7 +379,10 @@ def _get_single_timeseries_dataarray(
         add_feature_dim = False
         features_vary_with_analysis_time = True
         feature_dim_mask = None
-        if features == self.get_vars_names(category="state"):
+        if (
+            features == self.get_vars_names(category="state")
+            and split is not None
+        ):
             filename_format = STATE_FILENAME_FORMAT
             file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
             # only select one member for now
@@ -388,12 +394,14 @@ def _get_single_timeseries_dataarray(
                 len(features) + n_to_drop, dtype=bool
             )
             feature_dim_mask[self._remove_state_features_with_index] = False
-        elif features == ["toa_downwelling_shortwave_flux"]:
+        elif (
+            features == ["toa_downwelling_shortwave_flux"] and split is not None
+        ):
             filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
             file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
             add_feature_dim = True
             fp_samples = self.root_path / "samples" / split
-        elif features == ["open_water_fraction"]:
+        elif features == ["open_water_fraction"] and split is not None:
             filename_format = OPEN_WATER_FILENAME_FORMAT
             file_dims = ["y", "x", "feature"]
             add_feature_dim = True
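The added `and split is not None` guards are what let mypy accept the later `self.root_path / "samples" / split` joins: inside a guarded branch the `Optional[str]` is narrowed to `str`. A reduced sketch of the same narrowing:

```python
from pathlib import Path
from typing import Optional


def samples_dir(root_path: Path, split: Optional[str]) -> Optional[Path]:
    if split is not None:
        # split is str here, so the path join type-checks
        return root_path / "samples" / split
    return None


print(samples_dir(Path("/data"), "train"))  # /data/samples/train
print(samples_dir(Path("/data"), None))  # None
```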
3 changes: 2 additions & 1 deletion neural_lam/datastore/plot_example.py
@@ -163,6 +163,7 @@ def _parse_dict(arg_str):

     selection = dict(args.selection)
     index_selection = dict(args.index_selection)
+    datastore_kind = args.datastore_kind

     # check that column dimension is not in the selection
     if args.col_dim.format(category=args.category) in selection:
@@ -173,7 +174,7 @@
         )

     datastore = init_datastore(
-        datastore_kind=args.datastore_kind,
+        datastore_kind=datastore_kind,
         config_path=args.datastore_config_path,
     )
20 changes: 11 additions & 9 deletions neural_lam/models/ar_model.py
@@ -1,7 +1,7 @@
 # Standard library
 import os
 import warnings
-from typing import List, Union
+from typing import Any, Dict, List

 # Third-party
 import matplotlib.pyplot as plt
@@ -43,6 +43,8 @@ def __init__(
         da_static_features = datastore.get_dataarray(
             category="static", split=None, standardize=True
         )
+        if da_static_features is None:
+            raise ValueError("Static features are required for ARModel")
         da_state_stats = datastore.get_standardization_dataarray(
             category="state"
         )
@@ -127,10 +129,10 @@ def __init__(
             "interior_mask", 1.0 - self.boundary_mask, persistent=False
         )  # (num_grid_nodes, 1), 1 for non-border

-        self.val_metrics = {
+        self.val_metrics: Dict[str, List] = {
             "mse": [],
         }
-        self.test_metrics = {
+        self.test_metrics: Dict[str, List] = {
             "mse": [],
             "mae": [],
         }
@@ -145,12 +147,12 @@ def __init__(
         self.plotted_examples = 0

         # For storing spatial loss maps during evaluation
-        self.spatial_loss_maps = []
+        self.spatial_loss_maps: List[Any] = []

     def _create_dataarray_from_tensor(
         self,
         tensor: torch.Tensor,
-        time: Union[int, List[int]],
+        time: torch.Tensor,
         split: str,
         category: str,
     ) -> xr.DataArray:
@@ -166,9 +168,9 @@ def _create_dataarray_from_tensor(
             The tensor to convert to a `xr.DataArray` with dimensions [time,
             grid_index, feature]. The tensor will be copied to the CPU if it is
             not already there.
-        time : Union[int,List[int]]
-            The time index or indices for the data, given as integers or a list
-            of integers representing epoch time in nanoseconds. The ints will be
+        time : torch.Tensor
+            The time index or indices for the data, given as tensor representing
+            epoch time in nanoseconds. The tensor will be
             copied to the CPU memory if they are not already there.
         split : str
             The split of the data, either 'train', 'val', or 'test'
@@ -181,7 +183,7 @@ def _create_dataarray_from_tensor(
         weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
         time = np.array(time.cpu(), dtype="datetime64[ns]")
         da = weather_dataset.create_dataarray_from_tensor(
-            tensor=tensor.cpu().numpy(), time=time, category=category
+            tensor=tensor, time=time, category=category
         )
         return da
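The new `Dict[str, List]` and `List[Any]` annotations exist because empty literals give mypy nothing to infer from: a bare `self.spatial_loss_maps = []` typically yields a "Need type annotation" error. A reduced sketch:

```python
from typing import Any, Dict, List

# Without annotations, mypy cannot infer element types for the empty lists
val_metrics: Dict[str, List] = {"mse": []}
spatial_loss_maps: List[Any] = []

val_metrics["mse"].append(0.42)
spatial_loss_maps.append({"step": 1})
```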
6 changes: 3 additions & 3 deletions neural_lam/vis.py
@@ -67,8 +67,8 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
 def plot_prediction(
     datastore: BaseRegularGridDatastore,
-    da_prediction: xr.DataArray = None,
-    da_target: xr.DataArray = None,
+    da_prediction: xr.DataArray,
+    da_target: xr.DataArray,
     title=None,
     vrange=None,
 ):
@@ -82,7 +82,7 @@
     if vrange is None:
         vmin = min(da_prediction.min(), da_target.min())
         vmax = max(da_prediction.max(), da_target.max())
-    else:
+    elif vrange is not None:
         vmin, vmax = vrange

     extent = datastore.get_xy_extent("state")
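Here the implicit-Optional problem is solved the other way around from `create_graph.py`: instead of widening the annotations to `Optional[xr.DataArray]`, the `None` defaults are dropped so both arrays become required and the body needs no narrowing. The two styles side by side, as simplified illustrative signatures:

```python
from typing import Optional

import numpy as np


def plot_prediction(da_prediction: np.ndarray, da_target: np.ndarray) -> None:
    # Required arguments: safe to use both without None checks
    print(float(da_prediction.min()), float(da_target.max()))


def create_graph(n_max_levels: Optional[int] = None) -> None:
    # Optional argument: must be narrowed before use as an int
    print("all levels" if n_max_levels is None else n_max_levels)
```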