diff --git a/climetlab/readers/__init__.py b/climetlab/readers/__init__.py index af52116b..16b89114 100644 --- a/climetlab/readers/__init__.py +++ b/climetlab/readers/__init__.py @@ -37,6 +37,8 @@ def multi_merge(cls, readers): class MultiReaders: + backend_kwargs = {} + def __init__(self, readers): self.readers = readers @@ -58,6 +60,8 @@ def preprocess(ds): assert options == opts, f"{options} != {opts}" options.update(kwargs) + options.setdefault("backend_kwargs", {}) + options["backend_kwargs"].update(self.backend_kwargs) return xr.open_mfdataset( [r.path for r in self.readers], diff --git a/climetlab/readers/grib.py b/climetlab/readers/grib.py index aa513ead..b3371563 100644 --- a/climetlab/readers/grib.py +++ b/climetlab/readers/grib.py @@ -220,6 +220,7 @@ def __iter__(self): class MultiGribReaders(MultiReaders): engine = "cfgrib" + backend_kwargs = {"squeeze": False} class GRIBReader(Reader): @@ -249,12 +250,14 @@ def __getitem__(self, n): def __len__(self): return len(self._items()) - def to_xarray(self): - import xarray as xr + def to_xarray(self, **kwargs): + # So we use the same code + return MultiGribReaders([self]).to_xarray(**kwargs) + # import xarray as xr - params = self.source.cfgrib_options() - ds = xr.open_dataset(self.path, engine="cfgrib", **params) - return self.source.post_xarray_open_dataset_hook(ds) + # params = self.source.cfgrib_options() + # ds = xr.open_dataset(self.path, engine="cfgrib", **params) + # return self.source.post_xarray_open_dataset_hook(ds) # @dict_args # def sel(self, **kwargs): diff --git a/climetlab/readers/netcdf.py b/climetlab/readers/netcdf.py index 41c629ba..eb85f050 100644 --- a/climetlab/readers/netcdf.py +++ b/climetlab/readers/netcdf.py @@ -312,8 +312,10 @@ def _get_fields(self, ds): # noqa C901 return fields - def to_xarray(self): - return xr.open_dataset(self.path, engine="netcdf4") + def to_xarray(self, **kwargs): + # So we use the same code + return MultiNetcdfReaders([self]).to_xarray(**kwargs) + # return xr.open_dataset(self.path, engine="netcdf4") @classmethod def multi_merge(cls, readers): diff --git a/tests/test_grib.py b/tests/test_grib.py index d70e1a53..09768243 100644 --- a/tests/test_grib.py +++ b/tests/test_grib.py @@ -34,7 +34,7 @@ def test_sel(): s.sel(shortName="2t") -@pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") +# @pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") def test_multi(): s1 = load_source( "cds", diff --git a/tests/test_netcdf.py b/tests/test_netcdf.py index 4081fe4b..1b2b15da 100644 --- a/tests/test_netcdf.py +++ b/tests/test_netcdf.py @@ -26,6 +26,7 @@ def test_multi(): date="2021-03-01", format="netcdf", ) + print(s1.to_xarray()) s2 = load_source( "cds", "reanalysis-era5-single-levels", @@ -34,6 +35,7 @@ def test_multi(): date="2021-03-02", format="netcdf", ) + print(s2.to_xarray()) source = load_source("multi", s1, s2) for s in source: