From 75d4ddd2f92ea7fd80dfaa36d559bf9d793392e0 Mon Sep 17 00:00:00 2001
From: sadamov
Date: Tue, 4 Feb 2025 17:49:19 +0100
Subject: [PATCH] added boundaries

Crop the boundary dataset (PATH_BOUNDARY) to the interior domain in the
rotated projection and draw it on all three map subplots before the
existing fields are plotted.

- add PATH_BOUNDARY and a var_mapping from interior to boundary names
- compute the interior lon/lat bounds once, outside the plotting loop
- mask the boundary dataset to the padded rotated bounds
- include the boundary field in the shared colour range of the maps
- select the boundary field by timestamp, not by ground-truth index
---
Reviewer notes (this section is ignored when the patch is applied):
- The commit id and the "index" blob ids below belong to the originally
  submitted revision; the post-image id no longer matches these contents.

 docs/notebooks/gridded_cosmo.py | 169 +++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 147 insertions(+), 22 deletions(-)

diff --git a/docs/notebooks/gridded_cosmo.py b/docs/notebooks/gridded_cosmo.py
index cd04d5d8..aa959ec1 100644
--- a/docs/notebooks/gridded_cosmo.py
+++ b/docs/notebooks/gridded_cosmo.py
@@ -16,6 +16,7 @@
 import numpy as np
 import pandas as pd
 import xarray as xr
+from mllam_data_prep.ops.cropping import crop_with_convex_hull
 from pysteps.verification.salscores import sal  # requires scikit-image
 from scipy.stats import gaussian_kde, wasserstein_distance
 from scores.categorical import BinaryContingencyManager
@@ -36,12 +37,16 @@
 PATH_GROUND_TRUTH = "/capstor/store/cscs/swissai/a01/sadamov/cosmo_sample.zarr"
 PATH_NWP = "/capstor/store/cscs/swissai/a01/sadamov/cosmo_sample.zarr"
 PATH_ML = "/capstor/store/cscs/swissai/a01/sadamov/cosmo_sample.zarr"
+PATH_BOUNDARY = "/capstor/store/cscs/swissai/a01/sadamov/era5_wb_2015_2020.zarr"
 # Temporal resolution in hours.
 TEMPORAL_RESOLUTION = 1
 # Selection of time steps (e.g. ["2021-01-01T06", "2021-01-05T00"], or [None,
 # None])
-DATETIMES = ["2016-01-01T00", "2016-01-10T00"]
+DATETIMES = [
+    "2016-01-01T00",
+    "2016-01-10T00",
+]
 # Selection spatial grid in projection
 X = [None, None]
 Y = [None, None]
 # Selection of vertical levels
@@ -66,7 +71,25 @@
     "ASOB_S",
     "ATHB_S",
 ]
-VARIABLES_3D = ["U", "V", "PP", "T", "QV", "W"]
+VARIABLES_3D = ["U", "V", "PP", "T", "RELHUM", "W"]
+# Mapping from interior variable names to boundary (ERA5) variable names
+var_mapping = {
+    # Surface variables
+    "U_10M": "10m_u_component_of_wind",
+    "T_2M": "2m_temperature",
+    "V_10M": "10m_v_component_of_wind",
+    "PMSL": "mean_sea_level_pressure",
+    "PS": "surface_pressure",
+    "TOT_PREC": "total_precipitation",
+    # Height level variables
+    "U": "u_component_of_wind",
+    "V": "v_component_of_wind",
+    "T": "temperature",
+    # NOTE(review): RELHUM (presumably relative humidity) is mapped to
+    # specific humidity; confirm that the two fields are comparable.
+    "RELHUM": "specific_humidity",
+    "W": "vertical_velocity",
+}
 # For some plots a random time step sample is selected
 RANDOM_SEED = 42
 TIME_SUBSAMPLES = 10
@@ -232,6 +255,83 @@
 ds_nwp_standardized = (ds_nwp - mean) / std
 ds_ml_standardized = (ds_ml - mean) / std
 
+# %%
+# Get the geographic coordinates of the ground-truth grid
+if hasattr(ds_gt, "longitude") and hasattr(ds_gt, "latitude"):
+    lons = ds_gt.longitude.values
+    lats = ds_gt.latitude.values
+elif hasattr(ds_gt, "lon") and hasattr(ds_gt, "lat"):
+    lons = ds_gt.lon.values
+    lats = ds_gt.lat.values
+else:
+    # Fail early instead of hitting a NameError on `lons` below.
+    raise ValueError(
+        "ds_gt has neither longitude/latitude nor lon/lat coordinates"
+    )
+
+lon_min = lons.min()
+lon_max = lons.max()
+lat_min = lats.min()
+lat_max = lats.max()
+
+# Transform domain bounds to rotated coordinates.
+# NOTE(review): only the (lon_min, lat_min) and (lon_max, lat_max) corners
+# are transformed; confirm that these two points bound the whole domain in
+# the rotated projection.
+corners_rot = PROJECTION.transform_points(
+    ccrs.PlateCarree(),
+    np.array([lon_min, lon_max]),
+    np.array([lat_min, lat_max]),
+)
+
+# Get rotated coordinate bounds
+rot_lon_min, rot_lon_max = corners_rot[:, 0].min(), corners_rot[:, 0].max()
+rot_lat_min, rot_lat_max = corners_rot[:, 1].min(), corners_rot[:, 1].max()
+
+
+# %%
+ds_boundary = xr.open_zarr(PATH_BOUNDARY)
+# 1. Normalize longitude to -180 to 180 range
+longitude_new = np.where(
+    ds_boundary["longitude"] > 180,
+    ds_boundary["longitude"] - 360,
+    ds_boundary["longitude"],
+)
+ds = ds_boundary.assign_coords(longitude=longitude_new).sortby([
+    "longitude",
+    "latitude",
+])
+
+# 2. Transform coordinates to rotated projection
+lon_mesh, lat_mesh = np.meshgrid(ds.longitude, ds.latitude)
+transformed = PROJECTION.transform_points(
+    ccrs.PlateCarree(), lon_mesh, lat_mesh
+)
+ds = ds.assign_coords(
+    rot_longitude=(["latitude", "longitude"], transformed[..., 0]),
+    rot_latitude=(["latitude", "longitude"], transformed[..., 1]),
+)
+
+# 3. Create and apply masks for selection
+# NOTE(review): `padding_degrees` is not defined anywhere in this patch;
+# confirm that the notebook defines it before this cell runs.
+lon_mask = (ds.rot_longitude >= (rot_lon_min - padding_degrees)) & (
+    ds.rot_longitude <= (rot_lon_max + padding_degrees)
+)
+lat_mask = (ds.rot_latitude >= (rot_lat_min - padding_degrees)) & (
+    ds.rot_latitude <= (rot_lat_max + padding_degrees)
+)
+ds_cropped = ds.where(lon_mask & lat_mask, drop=True)
+
+lon_mesh, lat_mesh = np.meshgrid(ds_cropped.longitude, ds_cropped.latitude)
+
+# # 4. Convert back to 0-360 range and sort
+# longitude_back = (ds_cropped["longitude"] + 360) % 360
+# ds_boundary = ds_cropped.assign_coords(longitude=longitude_back).sortby(
+#     "longitude"
+# )
+
+
 # %% [markdown]
 # ### 1. Maps
 
@@ -253,33 +353,58 @@
 time_selected = ds_gt.time[time_index]
 print(f"Selected time: {time_selected.values}")
 
-for var in VARIABLES_2D:
+# NOTE(review): only the first 2D variable is plotted; confirm this is
+# intended and not a debugging leftover.
+for var in [VARIABLES_2D[0]]:
     fig, axes = plt.subplots(
-        1,
-        3,
-        figsize=(21, 4),
-        dpi=DPI,
-        subplot_kw={"projection": PROJECTION},
+        1, 3, figsize=(21, 4), dpi=DPI, subplot_kw={"projection": PROJECTION}
     )
 
+    # Get corresponding ERA5 variable name
+    era5_var = var_mapping.get(var)
+
+    # Select data first
     ds_var = ds_gt[var].isel(time=time_index)
     ds_nwp_var = ds_nwp[var].isel(time=time_index)
     ds_ml_var = ds_ml[var].isel(time=time_index)
 
-    vmin = min(
-        ds_var.min().values, ds_nwp_var.min().values, ds_ml_var.min().values
-    )
-    vmax = max(
-        ds_var.max().values, ds_nwp_var.max().values, ds_ml_var.max().values
-    )
-
-    # Get coordinates
-    if hasattr(ds_gt, "longitude") and hasattr(ds_gt, "latitude"):
-        lons = ds_gt.longitude.values
-        lats = ds_gt.latitude.values
-    elif hasattr(ds_gt, "lon") and hasattr(ds_gt, "lat"):
-        lons = ds_gt.lon.values
-        lats = ds_gt.lat.values
+    # Combine all arrays for min/max calculation
+    arrays_for_minmax = [ds_var.values, ds_nwp_var.values, ds_ml_var.values]
+
+    # The boundary dataset is opened without a time selection, so pick its
+    # field by timestamp instead of reusing the ground-truth time index.
+    # The transpose matches the (latitude, longitude) layout of the
+    # lon_mesh/lat_mesh arrays used for plotting.
+    ds_cropped_var = None
+    if era5_var and era5_var in ds_cropped:
+        ds_cropped_var = (
+            ds_cropped[era5_var]
+            .sel(time=time_selected.values, method="nearest")
+            .transpose("latitude", "longitude")
+        )
+        arrays_for_minmax.append(ds_cropped_var.values)
+
+    # Calculate global min/max
+    combined_array = np.concatenate([
+        arr.flatten() for arr in arrays_for_minmax
+    ])
+    vmin = np.nanmin(combined_array)
+    vmax = np.nanmax(combined_array)
+
+    # Plot boundaries first for all subplots
+    if ds_cropped_var is not None:
+        for ax in axes:
+            ax.contourf(
+                lon_mesh,
+                lat_mesh,
+                ds_cropped_var.values,
+                transform=ccrs.PlateCarree(),
+                cmap="viridis",
+                vmin=vmin,
+                vmax=vmax,
+                alpha=0.5,
+                levels=20,
+            )
 
     # Plot ground truth
     im0 = axes[0].pcolormesh(