From ca9b4065c437b2a09db18960291db5937fe6f063 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 24 Nov 2024 16:00:41 +0000 Subject: [PATCH] add interpolation --- src/ai_models/__main__.py | 5 ++ src/ai_models/inputs/interpolate.py | 89 ----------------------------- src/ai_models/inputs/opendata.py | 6 +- src/ai_models/model.py | 2 +- src/ai_models/outputs/__init__.py | 50 ++++++++++++++-- 5 files changed, 54 insertions(+), 98 deletions(-) delete mode 100644 src/ai_models/inputs/interpolate.py diff --git a/src/ai_models/__main__.py b/src/ai_models/__main__.py index 37b8ec4..ee484aa 100644 --- a/src/ai_models/__main__.py +++ b/src/ai_models/__main__.py @@ -111,6 +111,11 @@ def _main(argv): choices=sorted(available_outputs()), ) + parser.add_argument( + "--interpolate", + help="Should the results be interpolated", + ) + parser.add_argument( "--date", default="-1", diff --git a/src/ai_models/inputs/interpolate.py b/src/ai_models/inputs/interpolate.py deleted file mode 100644 index d1c0f78..0000000 --- a/src/ai_models/inputs/interpolate.py +++ /dev/null @@ -1,89 +0,0 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import lru_cache - -import earthkit.data as ekd -import earthkit.regrid as ekr -import tqdm -from earthkit.data.core.temporary import temp_file - -LOG = logging.getLogger(__name__) - - -@lru_cache(maxsize=None) -def ll_0p25(): - return dict( - latitudeOfFirstGridPointInDegrees=90, - longitudeOfFirstGridPointInDegrees=0, - latitudeOfLastGridPointInDegrees=-90, - longitudeOfLastGridPointInDegrees=359.75, - iDirectionIncrementInDegrees=0.25, - jDirectionIncrementInDegrees=0.25, - Ni=1440, - Nj=721, - gridType="regular_ll", - ) - - -@lru_cache(maxsize=None) -def gg_n320(): - import eccodes - - sample = None - result = {} - try: - sample = eccodes.codes_new_from_samples("reduced_gg_pl_320_grib2", eccodes.CODES_PRODUCT_GRIB) - - for key in ("N", "Ni", "Nj"): - result[key] = eccodes.codes_get(sample, key) - - for key in ( - "latitudeOfFirstGridPointInDegrees", - "longitudeOfFirstGridPointInDegrees", - "latitudeOfLastGridPointInDegrees", - "longitudeOfLastGridPointInDegrees", - ): - result[key] = eccodes.codes_get_double(sample, key) - - pl = eccodes.codes_get_long_array(sample, "pl") - result["pl"] = pl.tolist() - result["gridType"] = "reduced_gg" - - return result - - finally: - if sample is not None: - eccodes.codes_release(sample) - - -METADATA = {"N320": gg_n320} - - -class Interpolate: - def __init__(self, *, source, target): - self.target = list(target) if isinstance(target, tuple) else target - self.source = list(source) if isinstance(source, tuple) else source - - def __call__(self, ds): - tmp = temp_file() - - out = ekd.new_grib_output(tmp.path) - out2 = ekd.new_grib_output("interpolated.grib2") - for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False): - data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.target)) - # result.append(f.clone(values=data, **METADATA[self.target]())) - out2.write(data, template=f, metadata=METADATA[self.target]()) - out.write(data, template=f, metadata=METADATA[self.target]()) - - out.close() - - result = ekd.from_source("file", tmp.path) - result._tmp = tmp - - return result diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 4af9948..629f760 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -14,9 +14,9 @@ from earthkit.data.indexing.fieldlist import FieldArray from multiurl import download +from ..interpolate import Interpolate from .base import RequestBasedInput from .compute import make_z_from_gh -from .interpolate import Interpolate from .recenter import recenter from .transform import NewMetadataField @@ -71,7 +71,7 @@ def _adjust(self, kwargs): if isinstance(grid, list): grid = tuple(grid) - kwargs["resol"], source, interp, oversampling, metadata = RESOLS[grid] + kwargs["resol"], source, interp, oversampling, _ = RESOLS[grid] r = dict(**kwargs) r.update(self.owner.retrieve) @@ -80,7 +80,7 @@ def _adjust(self, kwargs): logging.info("Interpolating input data from %s to %s.", source, grid) if oversampling: logging.warning("This will oversample the input data.") - return Interpolate(grid, source, metadata) + return Interpolate(source=source, target=grid) else: return _identity diff --git a/src/ai_models/model.py b/src/ai_models/model.py index 01eda32..5bf23f8 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -131,7 +131,7 @@ def collect_archive_requests(self, written): self.archiving[path].add(handle.as_namespace("mars")) def finalise(self): - self.output.flush() + self.output.close() if self.archive_requests: with open(self.archive_requests, "w") as f: diff --git a/src/ai_models/outputs/__init__.py b/src/ai_models/outputs/__init__.py index 9ecf386..e5d6020 100644 --- a/src/ai_models/outputs/__init__.py +++ b/src/ai_models/outputs/__init__.py @@ -21,7 +21,7 @@ class Output: def write(self, *args, **kwargs): pass - def flush(self, *args, **kwargs): + def close(self): pass @@ -104,6 +104,9 @@ def write(self, data, *args, check=False, **kwargs): return handle, path + def close(self): + self.output.close() + class FileOutput(GribOutputBase): def __init__(self, *args, **kwargs): @@ -167,8 +170,8 @@ def write(self, *args, **kwargs): return self.output.write(*args, **kwargs) - def flush(self, *args, **kwargs): - return self.output.flush(*args, **kwargs) + def close(self): + return self.output.close() class NoLabelling: @@ -181,8 +184,40 @@ def write(self, *args, **kwargs): kwargs["deleteLocalDefinition"] = 1 return self.output.write(*args, **kwargs) - def flush(self, *args, **kwargs): - return self.output.flush(*args, **kwargs) + def close(self): + return self.output.close() + + +class InterpolatedOutput: + def __init__(self, owner, output, interpolate, **kwargs): + self.owner = owner + self.output = output + try: + self.target = (float(interpolate), float(interpolate)) + except ValueError: + self.target = interpolate.upper() + + @cached_property + def interpolator(self): + from ..interpolate import Interpolate + + return Interpolate(source=self.owner.grid, target=self.target) + + def write(self, values, template, *args, **kwargs): + + if values is None: + values = template.to_numpy(flatten=True) + # We need to extract a few more metadata from the template + for m in ("date", "time", "step", "param", "paramId", "shortName"): + kwargs[m] = template.metadata(m) + + values, metadata = self.interpolator.interpolate(values) + kwargs.update(metadata) + + return self.output.write(values, template, *args, **kwargs) + + def close(self): + return self.output.close() def get_output(name, owner, *args, **kwargs): @@ -191,6 +226,11 @@ def get_output(name, owner, *args, **kwargs): result = HindcastReLabel(owner, result, **kwargs) if owner.expver is None: result = NoLabelling(owner, result, **kwargs) + + if kwargs.get("interpolate") is not None: + # Interpolate the output + result = InterpolatedOutput(owner, result, **kwargs) + return result