From 4f13df76c9a6d52a0130ced0db706e051dda985e Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Thu, 28 Nov 2024 14:30:19 +0000 Subject: [PATCH 01/17] initial commit, add ability to clamp predicted outputs to limits supplied in config --- neural_lam/config.py | 17 ++++ neural_lam/models/base_graph_model.py | 91 ++++++++++++++++++- .../mdp/danra_100m_winds/config.yaml | 6 ++ .../mdp/danra_100m_winds/danra.datastore.yaml | 18 ++++ 4 files changed, 131 insertions(+), 1 deletion(-) diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e09697..3787a466 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -67,6 +67,21 @@ class UniformFeatureWeighting: pass +@dataclasses.dataclass +class OutputClamping: + """ + Configuration for clamping the output of the model. + + Attributes + ---------- + lower : Dict[str, float] + The minimum value to clamp each output feature to. + upper : Dict[str, float] + The maximum value to clamp each output feature to. + """ + + lower: Dict[str, float] = dataclasses.field(default_factory=dict) + upper: Dict[str, float] = dataclasses.field(default_factory=dict) @dataclasses.dataclass class TrainingConfig: @@ -86,6 +101,8 @@ class TrainingConfig: ManualStateFeatureWeighting, UniformFeatureWeighting ] = dataclasses.field(default_factory=UniformFeatureWeighting) + output_clamping: OutputClamping = dataclasses.field(default_factory=OutputClamping) + @dataclasses.dataclass class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard): diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 6233b4d1..9f9110dc 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -79,6 +79,90 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): layer_norm=False, ) # No layer norm on this one + # Compute indices and define clamping functions + self.prepare_clamping_parames(config,datastore) + + def prepare_clamping_parames(self,config:NeuralLAMConfig,datastore:BaseDatastore): + """ + Prepare parameters for clamping predicted values to valid range + """ + + # Read configs + state_feature_names = datastore.get_vars_names(category="state") + lower_lims = config.training.output_clamping.lower + upper_lims = config.training.output_clamping.upper + + # Check that limits in config are for valid features + unknown_features_lower = set(lower_lims.keys()) - set(state_feature_names) + unknown_features_upper = set(upper_lims.keys()) - set(state_feature_names) + if unknown_features_lower or unknown_features_upper: + raise ValueError(f"State feature limits were provided for unknown features: {unknown_features_lower.union(unknown_features_upper)}") + + # Constant parameters for clamping + sharpness_sigmoid,softplus_sharpness = 1,1 + sigmoid_center,softplus_center = 0,0 + + normalize_clamping_lim = lambda x,feature_idx: (x-self.state_mean[feature_idx])/self.state_std[feature_idx] + + # Check which clamping functions to use for each feature + sigmoid_lower_upper_idx = [] + sigmoid_lower_lims = [] + sigmoid_upper_lims = [] + + softplus_lower_idx = [] + softplus_lower_lims = [] + + softplus_upper_idx = [] + softplus_upper_lims = [] + + for feature_idx,feature in enumerate(state_feature_names): + if feature in lower_lims and feature in upper_lims: + sigmoid_lower_upper_idx.append(feature_idx) + sigmoid_lower_lims.append(normalize_clamping_lim(lower_lims[feature], feature_idx)) + sigmoid_upper_lims.append(normalize_clamping_lim(upper_lims[feature], 
feature_idx)) + elif feature in lower_lims and feature not in upper_lims: + softplus_lower_idx.append(feature_idx) + softplus_lower_lims.append(normalize_clamping_lim(lower_lims[feature], feature_idx)) + elif feature not in lower_lims and feature in upper_lims: + softplus_upper_idx.append(feature_idx) + softplus_upper_lims.append(normalize_clamping_lim(upper_lims[feature], feature_idx)) + + # Convert to tensors + self.register_buffer('sigmoid_lower_lims', torch.tensor(sigmoid_lower_lims), persistent=False) + self.register_buffer('sigmoid_upper_lims', torch.tensor(sigmoid_upper_lims), persistent=False) + self.register_buffer('softplus_lower_lims', torch.tensor(softplus_lower_lims), persistent=False) + self.register_buffer('softplus_upper_lims', torch.tensor(softplus_upper_lims), persistent=False) + + self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx) + self.clamp_lower_idx = torch.tensor(softplus_lower_idx) + self.clamp_upper_idx = torch.tensor(softplus_upper_idx) + + # Define clamping functions + self.clamp_lower_upper = lambda state: self.sigmoid_lower_lims+(self.sigmoid_upper_lims-self.sigmoid_lower_lims)*torch.sigmoid(sharpness_sigmoid*(state-sigmoid_center)) + self.clamp_lower = lambda state: self.softplus_lower_lims+torch.nn.functional.softplus(state-softplus_center,beta=softplus_sharpness) + self.clamp_upper = lambda state: self.softplus_upper_lims-torch.nn.functional.softplus(state-softplus_center,beta=softplus_sharpness) + + def clamp_prediction(self,state): + """ + Clamp prediction to valid range supplied in config + + state: (B, num_grid_nodes, feature_dim) + """ + + # Sigmoid/logistic clamps between ]a,b[ + if self.clamp_lower_upper_idx.numel() > 0: + state[:,:,self.clamp_lower_upper_idx] = self.clamp_lower_upper(state[:,:,self.clamp_lower_upper_idx]) + + # Softplus clamps between ]a,infty[ + if self.clamp_lower_idx.numel() > 0: + state[:,:,self.clamp_lower_idx] = self.clamp_lower(state[:,:,self.clamp_lower_idx]) + + # Softplus clamps between ]-infty,b[ + if self.clamp_upper_idx.numel() > 0: + state[:,:,self.clamp_upper_idx] = self.clamp_upper(state[:,:,self.clamp_upper_idx]) + + return state + def get_num_mesh(self): """ Compute number of mesh nodes from loaded features, @@ -174,4 +258,9 @@ def predict_step(self, prev_state, prev_prev_state, forcing): rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean # Residual connection for full state - return prev_state + rescaled_delta_mean, pred_std + new_state = prev_state + rescaled_delta_mean + + # Clamp values to valid range + new_state_clamped = self.clamp_prediction(new_state) + + return new_state_clamped, pred_std diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml index 0bb5c5ec..7c998fe9 100644 --- a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml +++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml @@ -7,3 +7,9 @@ training: weights: u100m: 1.0 v100m: 1.0 + output_clamping: + lower: + t2m: 0.0 + r2m: 0 + upper: + r2m: 100.0 diff --git a/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml index 3edf1267..0d159f2d 100644 --- a/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml +++ b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml @@ -72,6 +72,24 @@ inputs: method: stack_variables_by_var_name name_format: "{var_name}" target_output_variable: forcing + + danra_surface: + path: 
https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + - r2m + - t2m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + state_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: state danra_lsm: path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr From 50ce774299e78ea04ba791795cb890712562bf49 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Fri, 29 Nov 2024 08:57:41 +0000 Subject: [PATCH 02/17] linting --- neural_lam/config.py | 6 +- neural_lam/models/base_graph_model.py | 126 +++++++++++++----- .../mdp/danra_100m_winds/config.yaml | 2 + .../mdp/danra_100m_winds/danra.datastore.yaml | 4 +- 4 files changed, 99 insertions(+), 39 deletions(-) diff --git a/neural_lam/config.py b/neural_lam/config.py index 3787a466..9deb15d6 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -67,6 +67,7 @@ class UniformFeatureWeighting: pass + @dataclasses.dataclass class OutputClamping: """ @@ -83,6 +84,7 @@ class OutputClamping: lower: Dict[str, float] = dataclasses.field(default_factory=dict) upper: Dict[str, float] = dataclasses.field(default_factory=dict) + @dataclasses.dataclass class TrainingConfig: """ @@ -101,7 +103,9 @@ class TrainingConfig: ManualStateFeatureWeighting, UniformFeatureWeighting ] = dataclasses.field(default_factory=UniformFeatureWeighting) - output_clamping: OutputClamping = dataclasses.field(default_factory=OutputClamping) + output_clamping: OutputClamping = dataclasses.field( + default_factory=OutputClamping + ) @dataclasses.dataclass diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 9f9110dc..89186051 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -80,9 +80,11 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): ) # No layer norm on this one # Compute indices and define clamping functions - self.prepare_clamping_parames(config,datastore) + self.prepare_clamping_parames(config, datastore) - def prepare_clamping_parames(self,config:NeuralLAMConfig,datastore:BaseDatastore): + def prepare_clamping_parames( + self, config: NeuralLAMConfig, datastore: BaseDatastore + ): """ Prepare parameters for clamping predicted values to valid range """ @@ -93,16 +95,26 @@ def prepare_clamping_parames(self,config:NeuralLAMConfig,datastore:BaseDatastore upper_lims = config.training.output_clamping.upper # Check that limits in config are for valid features - unknown_features_lower = set(lower_lims.keys()) - set(state_feature_names) - unknown_features_upper = set(upper_lims.keys()) - set(state_feature_names) + unknown_features_lower = set(lower_lims.keys()) - set( + state_feature_names + ) + unknown_features_upper = set(upper_lims.keys()) - set( + state_feature_names + ) if unknown_features_lower or unknown_features_upper: - raise ValueError(f"State feature limits were provided for unknown features: {unknown_features_lower.union(unknown_features_upper)}") + raise ValueError( + "State feature limits were provided for unknown features: " + f"{unknown_features_lower.union(unknown_features_upper)}" + ) # Constant parameters for clamping - sharpness_sigmoid,softplus_sharpness = 1,1 - sigmoid_center,softplus_center = 0,0 - - normalize_clamping_lim = lambda x,feature_idx: (x-self.state_mean[feature_idx])/self.state_std[feature_idx] + sharpness_sigmoid, softplus_sharpness = 1, 1 + sigmoid_center, 
softplus_center = 0, 0 + + normalize_clamping_lim = ( + lambda x, feature_idx: (x - self.state_mean[feature_idx]) + / self.state_std[feature_idx] + ) # Check which clamping functions to use for each feature sigmoid_lower_upper_idx = [] @@ -114,55 +126,97 @@ def prepare_clamping_parames(self,config:NeuralLAMConfig,datastore:BaseDatastore softplus_upper_idx = [] softplus_upper_lims = [] - - for feature_idx,feature in enumerate(state_feature_names): + + for feature_idx, feature in enumerate(state_feature_names): if feature in lower_lims and feature in upper_lims: sigmoid_lower_upper_idx.append(feature_idx) - sigmoid_lower_lims.append(normalize_clamping_lim(lower_lims[feature], feature_idx)) - sigmoid_upper_lims.append(normalize_clamping_lim(upper_lims[feature], feature_idx)) - elif feature in lower_lims and feature not in upper_lims: + sigmoid_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) + sigmoid_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + elif feature in lower_lims and feature not in upper_lims: softplus_lower_idx.append(feature_idx) - softplus_lower_lims.append(normalize_clamping_lim(lower_lims[feature], feature_idx)) + softplus_lower_lims.append( + normalize_clamping_lim(lower_lims[feature], feature_idx) + ) elif feature not in lower_lims and feature in upper_lims: softplus_upper_idx.append(feature_idx) - softplus_upper_lims.append(normalize_clamping_lim(upper_lims[feature], feature_idx)) - + softplus_upper_lims.append( + normalize_clamping_lim(upper_lims[feature], feature_idx) + ) + # Convert to tensors - self.register_buffer('sigmoid_lower_lims', torch.tensor(sigmoid_lower_lims), persistent=False) - self.register_buffer('sigmoid_upper_lims', torch.tensor(sigmoid_upper_lims), persistent=False) - self.register_buffer('softplus_lower_lims', torch.tensor(softplus_lower_lims), persistent=False) - self.register_buffer('softplus_upper_lims', torch.tensor(softplus_upper_lims), persistent=False) + self.register_buffer( + "sigmoid_lower_lims", + torch.tensor(sigmoid_lower_lims), + persistent=False, + ) + self.register_buffer( + "sigmoid_upper_lims", + torch.tensor(sigmoid_upper_lims), + persistent=False, + ) + self.register_buffer( + "softplus_lower_lims", + torch.tensor(softplus_lower_lims), + persistent=False, + ) + self.register_buffer( + "softplus_upper_lims", + torch.tensor(softplus_upper_lims), + persistent=False, + ) self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx) self.clamp_lower_idx = torch.tensor(softplus_lower_idx) self.clamp_upper_idx = torch.tensor(softplus_upper_idx) - + # Define clamping functions - self.clamp_lower_upper = lambda state: self.sigmoid_lower_lims+(self.sigmoid_upper_lims-self.sigmoid_lower_lims)*torch.sigmoid(sharpness_sigmoid*(state-sigmoid_center)) - self.clamp_lower = lambda state: self.softplus_lower_lims+torch.nn.functional.softplus(state-softplus_center,beta=softplus_sharpness) - self.clamp_upper = lambda state: self.softplus_upper_lims-torch.nn.functional.softplus(state-softplus_center,beta=softplus_sharpness) + self.clamp_lower_upper = lambda state: self.sigmoid_lower_lims + ( + self.sigmoid_upper_lims - self.sigmoid_lower_lims + ) * torch.sigmoid(sharpness_sigmoid * (state - sigmoid_center)) + self.clamp_lower = ( + lambda state: self.softplus_lower_lims + + torch.nn.functional.softplus( + state - softplus_center, beta=softplus_sharpness + ) + ) + self.clamp_upper = ( + lambda state: self.softplus_upper_lims + - torch.nn.functional.softplus( + state - 
softplus_center, beta=softplus_sharpness + ) + ) - def clamp_prediction(self,state): + def clamp_prediction(self, state): """ Clamp prediction to valid range supplied in config state: (B, num_grid_nodes, feature_dim) """ - + # Sigmoid/logistic clamps between ]a,b[ if self.clamp_lower_upper_idx.numel() > 0: - state[:,:,self.clamp_lower_upper_idx] = self.clamp_lower_upper(state[:,:,self.clamp_lower_upper_idx]) - + state[:, :, self.clamp_lower_upper_idx] = self.clamp_lower_upper( + state[:, :, self.clamp_lower_upper_idx] + ) + # Softplus clamps between ]a,infty[ if self.clamp_lower_idx.numel() > 0: - state[:,:,self.clamp_lower_idx] = self.clamp_lower(state[:,:,self.clamp_lower_idx]) - + state[:, :, self.clamp_lower_idx] = self.clamp_lower( + state[:, :, self.clamp_lower_idx] + ) + # Softplus clamps between ]-infty,b[ if self.clamp_upper_idx.numel() > 0: - state[:,:,self.clamp_upper_idx] = self.clamp_upper(state[:,:,self.clamp_upper_idx]) - + state[:, :, self.clamp_upper_idx] = self.clamp_upper( + state[:, :, self.clamp_upper_idx] + ) + return state - + def get_num_mesh(self): """ Compute number of mesh nodes from loaded features, @@ -258,8 +312,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing): rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean # Residual connection for full state - new_state = prev_state + rescaled_delta_mean - + new_state = prev_state + rescaled_delta_mean + # Clamp values to valid range new_state_clamped = self.clamp_prediction(new_state) diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml index 7c998fe9..a57266f4 100644 --- a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml +++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml @@ -7,6 +7,8 @@ training: weights: u100m: 1.0 v100m: 1.0 + t2m: 1.0 + r2m: 1.0 output_clamping: lower: t2m: 0.0 diff --git a/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml index 0d159f2d..e601cc02 100644 --- a/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml +++ b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml @@ -55,7 +55,7 @@ inputs: dims: [x, y] target_output_variable: state - danra_surface: + danra_surface_forcing: path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr dims: [time, x, y] variables: @@ -72,7 +72,7 @@ inputs: method: stack_variables_by_var_name name_format: "{var_name}" target_output_variable: forcing - + danra_surface: path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr dims: [time, x, y] From a247f9a849f02c2835761b3142cd29478beba269 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Fri, 29 Nov 2024 10:19:17 +0000 Subject: [PATCH 03/17] ensure only state delta is clamped but enforcing limits on the final outputted state --- neural_lam/models/base_graph_model.py | 63 ++++++++++++++++----------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 89186051..bab06b6e 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -108,8 +108,8 @@ def prepare_clamping_parames( ) # Constant parameters for clamping - sharpness_sigmoid, softplus_sharpness = 1, 1 - sigmoid_center, softplus_center = 0, 0 + sharpness_sigmoid = softplus_sharpness = 1 + sigmoid_center = softplus_center = 0 
normalize_clamping_lim = ( lambda x, feature_idx: (x - self.state_mean[feature_idx]) @@ -174,48 +174,61 @@ def prepare_clamping_parames( self.clamp_upper_idx = torch.tensor(softplus_upper_idx) # Define clamping functions - self.clamp_lower_upper = lambda state: self.sigmoid_lower_lims + ( - self.sigmoid_upper_lims - self.sigmoid_lower_lims - ) * torch.sigmoid(sharpness_sigmoid * (state - sigmoid_center)) - self.clamp_lower = ( - lambda state: self.softplus_lower_lims + self.clamp_lower_upper = lambda delta, lower, upper: ( + lower + + (upper - lower) + * torch.sigmoid(sharpness_sigmoid * (delta - sigmoid_center)) + ) + self.clamp_lower = lambda delta, lower: ( + lower + torch.nn.functional.softplus( - state - softplus_center, beta=softplus_sharpness + delta - softplus_center, beta=softplus_sharpness ) ) - self.clamp_upper = ( - lambda state: self.softplus_upper_lims + self.clamp_upper = lambda delta, upper: ( + upper - torch.nn.functional.softplus( - state - softplus_center, beta=softplus_sharpness + delta - softplus_center, beta=softplus_sharpness ) ) - def clamp_prediction(self, state): + def clamp_prediction(self, state_delta, prev_state): """ Clamp prediction to valid range supplied in config - state: (B, num_grid_nodes, feature_dim) + state_delta: (B, num_grid_nodes, feature_dim) + prev_state: (B, num_grid_nodes, feature_dim) """ # Sigmoid/logistic clamps between ]a,b[ if self.clamp_lower_upper_idx.numel() > 0: - state[:, :, self.clamp_lower_upper_idx] = self.clamp_lower_upper( - state[:, :, self.clamp_lower_upper_idx] + idx = self.clamp_lower_upper_idx + + state_delta[:, :, idx] = self.clamp_lower_upper( + state_delta[:, :, idx], + self.sigmoid_lower_lims - prev_state[:, :, idx], + self.sigmoid_upper_lims - prev_state[:, :, idx], ) # Softplus clamps between ]a,infty[ if self.clamp_lower_idx.numel() > 0: - state[:, :, self.clamp_lower_idx] = self.clamp_lower( - state[:, :, self.clamp_lower_idx] + idx = self.clamp_lower_idx + + state_delta[:, :, idx] = self.clamp_lower( + state_delta[:, :, idx], + self.softplus_lower_lims - prev_state[:, :, idx], ) # Softplus clamps between ]-infty,b[ if self.clamp_upper_idx.numel() > 0: - state[:, :, self.clamp_upper_idx] = self.clamp_upper( - state[:, :, self.clamp_upper_idx] + idx = self.clamp_upper_idx + + state_delta[:, :, idx] = self.clamp_upper( + state_delta[:, :, idx], + self.softplus_upper_lims - prev_state[:, :, idx], ) - return state + return state_delta def get_num_mesh(self): """ @@ -311,10 +324,10 @@ def predict_step(self, prev_state, prev_prev_state, forcing): # Rescale with one-step difference statistics rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean - # Residual connection for full state - new_state = prev_state + rescaled_delta_mean - # Clamp values to valid range - new_state_clamped = self.clamp_prediction(new_state) + delta_clamped = self.clamp_prediction(rescaled_delta_mean, prev_state) + + # Residual connection for full state + new_state = prev_state + delta_clamped - return new_state_clamped, pred_std + return new_state, pred_std From 4a27c85723ca4d22306f52808bfe5362a64d8b8c Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Wed, 4 Dec 2024 10:26:57 +0000 Subject: [PATCH 04/17] update clamping method to use inverse method --- neural_lam/models/base_graph_model.py | 126 ++++++++++++++++---------- 1 file changed, 79 insertions(+), 47 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index bab06b6e..05c8b176 100644 --- 
a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -108,7 +108,7 @@ def prepare_clamping_parames( ) # Constant parameters for clamping - sharpness_sigmoid = softplus_sharpness = 1 + sigmoid_sharpness = softplus_sharpness = 1 sigmoid_center = softplus_center = 0 normalize_clamping_lim = ( @@ -148,87 +148,122 @@ def prepare_clamping_parames( ) # Convert to tensors - self.register_buffer( - "sigmoid_lower_lims", - torch.tensor(sigmoid_lower_lims), - persistent=False, - ) - self.register_buffer( - "sigmoid_upper_lims", - torch.tensor(sigmoid_upper_lims), - persistent=False, - ) - self.register_buffer( - "softplus_lower_lims", - torch.tensor(softplus_lower_lims), - persistent=False, - ) - self.register_buffer( - "softplus_upper_lims", - torch.tensor(softplus_upper_lims), - persistent=False, - ) + # self.register_buffer( + # "sigmoid_lower_lims", + # torch.tensor(sigmoid_lower_lims), + # persistent=False, + # ) + # self.register_buffer( + # "sigmoid_upper_lims", + # torch.tensor(sigmoid_upper_lims), + # persistent=False, + # ) + # self.register_buffer( + # "softplus_lower_lims", + # torch.tensor(softplus_lower_lims), + # persistent=False, + # ) + # self.register_buffer( + # "softplus_upper_lims", + # torch.tensor(softplus_upper_lims), + # persistent=False, + # ) + sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims) + sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims) + softplus_lower_lims = torch.tensor(softplus_lower_lims) + softplus_upper_lims = torch.tensor(softplus_upper_lims) self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx) self.clamp_lower_idx = torch.tensor(softplus_lower_idx) self.clamp_upper_idx = torch.tensor(softplus_upper_idx) # Define clamping functions - self.clamp_lower_upper = lambda delta, lower, upper: ( - lower - + (upper - lower) - * torch.sigmoid(sharpness_sigmoid * (delta - sigmoid_center)) + self.clamp_lower_upper = lambda x: ( + sigmoid_lower_lims + + (sigmoid_upper_lims - sigmoid_lower_lims) + * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) ) - self.clamp_lower = lambda delta, lower: ( - lower + self.clamp_lower = lambda x: ( + softplus_lower_lims + torch.nn.functional.softplus( - delta - softplus_center, beta=softplus_sharpness + x - softplus_center, beta=softplus_sharpness ) ) - self.clamp_upper = lambda delta, upper: ( - upper + self.clamp_upper = lambda x: ( + softplus_upper_lims - torch.nn.functional.softplus( - delta - softplus_center, beta=softplus_sharpness + softplus_center - x, beta=softplus_sharpness ) ) + # Define inverse clamping functions + def inverse_softplus(x, beta=1, threshold=20): + # If x*beta is above threshold, returns linear function + # for numerical stability + under_lim = x * beta <= threshold + x[under_lim] = torch.log(torch.expm1(x[under_lim] * beta)) / beta + return x + + def inverse_sigmoid(x): + return torch.log(x / (1 - x)) + + self.inverse_clamp_lower_upper = lambda x: ( + sigmoid_center + + inverse_sigmoid( + (x - sigmoid_lower_lims) + / (sigmoid_upper_lims - sigmoid_lower_lims) + ) + / sigmoid_sharpness + ) + self.inverse_clamp_lower = lambda x: ( + inverse_softplus(x - softplus_lower_lims, beta=softplus_sharpness) + + softplus_center + ) + self.inverse_clamp_upper = lambda x: ( + -inverse_softplus(softplus_upper_lims - x, beta=softplus_sharpness) + + softplus_center + ) + def clamp_prediction(self, state_delta, prev_state): """ Clamp prediction to valid range supplied in config + Returns the clamped new state after adding delta to original state state_delta: (B, 
num_grid_nodes, feature_dim) prev_state: (B, num_grid_nodes, feature_dim) """ + # Assign new state, but overwrite clamped values of each type later + new_state = prev_state + state_delta + # Sigmoid/logistic clamps between ]a,b[ if self.clamp_lower_upper_idx.numel() > 0: idx = self.clamp_lower_upper_idx - state_delta[:, :, idx] = self.clamp_lower_upper( - state_delta[:, :, idx], - self.sigmoid_lower_lims - prev_state[:, :, idx], - self.sigmoid_upper_lims - prev_state[:, :, idx], + new_state[:, :, idx] = self.clamp_lower_upper( + self.inverse_clamp_lower_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] ) # Softplus clamps between ]a,infty[ if self.clamp_lower_idx.numel() > 0: idx = self.clamp_lower_idx - state_delta[:, :, idx] = self.clamp_lower( - state_delta[:, :, idx], - self.softplus_lower_lims - prev_state[:, :, idx], + new_state[:, :, idx] = self.clamp_lower( + self.inverse_clamp_lower(prev_state[:, :, idx]) + + state_delta[:, :, idx] ) # Softplus clamps between ]-infty,b[ if self.clamp_upper_idx.numel() > 0: idx = self.clamp_upper_idx - state_delta[:, :, idx] = self.clamp_upper( - state_delta[:, :, idx], - self.softplus_upper_lims - prev_state[:, :, idx], + new_state[:, :, idx] = self.clamp_upper( + self.inverse_clamp_upper(prev_state[:, :, idx]) + + state_delta[:, :, idx] ) - return state_delta + return new_state def get_num_mesh(self): """ @@ -324,10 +359,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing): # Rescale with one-step difference statistics rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean - # Clamp values to valid range - delta_clamped = self.clamp_prediction(rescaled_delta_mean, prev_state) - - # Residual connection for full state - new_state = prev_state + delta_clamped + # Clamp values to valid range (also add the delta to the previous state) + new_state = self.clamp_prediction(rescaled_delta_mean, prev_state) return new_state, pred_std From da1480c30403a3fe58c0597e997fb52ff9f25105 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Wed, 4 Dec 2024 10:40:03 +0000 Subject: [PATCH 05/17] prevent inverse sigmoid and softplus from returning +/- inf --- neural_lam/models/base_graph_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 05c8b176..342a570c 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -201,11 +201,17 @@ def inverse_softplus(x, beta=1, threshold=20): # If x*beta is above threshold, returns linear function # for numerical stability under_lim = x * beta <= threshold - x[under_lim] = torch.log(torch.expm1(x[under_lim] * beta)) / beta + x[under_lim] = ( + torch.log( + torch.clamp_min(torch.expm1(x[under_lim] * beta), 1e-6) + ) + / beta + ) return x def inverse_sigmoid(x): - return torch.log(x / (1 - x)) + x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6) + return torch.log(x_clamped / (1 - x_clamped)) self.inverse_clamp_lower_upper = lambda x: ( sigmoid_center From 5c7567d096e1d0068c7f44325fcc9d1793dc4e9a Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Wed, 11 Dec 2024 13:41:10 +0000 Subject: [PATCH 06/17] added test --- neural_lam/models/base_graph_model.py | 84 +++--- .../mdp/danra_100m_winds/config.yaml | 1 + tests/test_clamping.py | 283 ++++++++++++++++++ 3 files changed, 320 insertions(+), 48 deletions(-) create mode 100644 tests/test_clamping.py diff --git a/neural_lam/models/base_graph_model.py 
b/neural_lam/models/base_graph_model.py index 342a570c..212aefc7 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -108,8 +108,10 @@ def prepare_clamping_parames( ) # Constant parameters for clamping - sigmoid_sharpness = softplus_sharpness = 1 - sigmoid_center = softplus_center = 0 + self.sigmoid_sharpness = 1 + self.softplus_sharpness = 1 + self.sigmoid_center = 0 + self.softplus_center = 0 normalize_clamping_lim = ( lambda x, feature_idx: (x - self.state_mean[feature_idx]) @@ -129,6 +131,11 @@ def prepare_clamping_parames( for feature_idx, feature in enumerate(state_feature_names): if feature in lower_lims and feature in upper_lims: + assert ( + lower_lims[feature] < upper_lims[feature] + ), f'Invalid clamping limits for feature "{feature}",\ + lower: {lower_lims[feature]}, larger than\ + upper: {upper_lims[feature]}' sigmoid_lower_upper_idx.append(feature_idx) sigmoid_lower_lims.append( normalize_clamping_lim(lower_lims[feature], feature_idx) @@ -147,31 +154,10 @@ def prepare_clamping_parames( normalize_clamping_lim(upper_lims[feature], feature_idx) ) - # Convert to tensors - # self.register_buffer( - # "sigmoid_lower_lims", - # torch.tensor(sigmoid_lower_lims), - # persistent=False, - # ) - # self.register_buffer( - # "sigmoid_upper_lims", - # torch.tensor(sigmoid_upper_lims), - # persistent=False, - # ) - # self.register_buffer( - # "softplus_lower_lims", - # torch.tensor(softplus_lower_lims), - # persistent=False, - # ) - # self.register_buffer( - # "softplus_upper_lims", - # torch.tensor(softplus_upper_lims), - # persistent=False, - # ) - sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims) - sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims) - softplus_lower_lims = torch.tensor(softplus_lower_lims) - softplus_upper_lims = torch.tensor(softplus_upper_lims) + self.sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims) + self.sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims) + self.softplus_lower_lims = torch.tensor(softplus_lower_lims) + self.softplus_upper_lims = torch.tensor(softplus_upper_lims) self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx) self.clamp_lower_idx = torch.tensor(softplus_lower_idx) @@ -179,20 +165,20 @@ def prepare_clamping_parames( # Define clamping functions self.clamp_lower_upper = lambda x: ( - sigmoid_lower_lims - + (sigmoid_upper_lims - sigmoid_lower_lims) - * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) + self.sigmoid_lower_lims + + (self.sigmoid_upper_lims - self.sigmoid_lower_lims) + * torch.sigmoid(self.sigmoid_sharpness * (x - self.sigmoid_center)) ) self.clamp_lower = lambda x: ( - softplus_lower_lims + self.softplus_lower_lims + torch.nn.functional.softplus( - x - softplus_center, beta=softplus_sharpness + x - self.softplus_center, beta=self.softplus_sharpness ) ) self.clamp_upper = lambda x: ( - softplus_upper_lims + self.softplus_upper_lims - torch.nn.functional.softplus( - softplus_center - x, beta=softplus_sharpness + self.softplus_center - x, beta=self.softplus_sharpness ) ) @@ -200,13 +186,11 @@ def prepare_clamping_parames( def inverse_softplus(x, beta=1, threshold=20): # If x*beta is above threshold, returns linear function # for numerical stability - under_lim = x * beta <= threshold - x[under_lim] = ( - torch.log( - torch.clamp_min(torch.expm1(x[under_lim] * beta), 1e-6) - ) - / beta + non_linear_part = ( + torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta ) + x = torch.where(x * beta <= threshold, non_linear_part, x) + return x def 
inverse_sigmoid(x): @@ -214,20 +198,24 @@ def inverse_sigmoid(x): return torch.log(x_clamped / (1 - x_clamped)) self.inverse_clamp_lower_upper = lambda x: ( - sigmoid_center + self.sigmoid_center + inverse_sigmoid( - (x - sigmoid_lower_lims) - / (sigmoid_upper_lims - sigmoid_lower_lims) + (x - self.sigmoid_lower_lims) + / (self.sigmoid_upper_lims - self.sigmoid_lower_lims) ) - / sigmoid_sharpness + / self.sigmoid_sharpness ) self.inverse_clamp_lower = lambda x: ( - inverse_softplus(x - softplus_lower_lims, beta=softplus_sharpness) - + softplus_center + inverse_softplus( + x - self.softplus_lower_lims, beta=self.softplus_sharpness + ) + + self.softplus_center ) self.inverse_clamp_upper = lambda x: ( - -inverse_softplus(softplus_upper_lims - x, beta=softplus_sharpness) - + softplus_center + -inverse_softplus( + self.softplus_upper_lims - x, beta=self.softplus_sharpness + ) + + self.softplus_center ) def clamp_prediction(self, state_delta, prev_state): diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml index a57266f4..d311c121 100644 --- a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml +++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml @@ -15,3 +15,4 @@ training: r2m: 0 upper: r2m: 100.0 + u100m: 100.0 diff --git a/tests/test_clamping.py b/tests/test_clamping.py new file mode 100644 index 00000000..457be631 --- /dev/null +++ b/tests/test_clamping.py @@ -0,0 +1,283 @@ +# Standard library +from pathlib import Path + +# Third-party +import torch + +# First-party +from neural_lam import config as nlconfig +from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.datastore.mdp import MDPDatastore +from neural_lam.models.graph_lam import GraphLAM +from tests.conftest import init_datastore_example + + +def test_clamping(): + datastore = init_datastore_example(MDPDatastore.SHORT_NAME) + + graph_name = "1level" + + graph_dir_path = Path(datastore.root_path) / "graph" / graph_name + + if not graph_dir_path.exists(): + create_graph_from_datastore( + datastore=datastore, + output_root_path=str(graph_dir_path), + n_max_levels=1, + ) + + class ModelArgs: + output_std = False + loss = "mse" + restore_opt = False + n_example_pred = 1 + graph = graph_name + hidden_dim = 4 + hidden_layers = 1 + processor_layers = 2 + mesh_aggr = "sum" + lr = 1.0e-3 + val_steps_to_log = [1, 3] + metrics_watch = [] + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + + model_args = ModelArgs() + + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ), + training=nlconfig.TrainingConfig( + output_clamping=nlconfig.OutputClamping( + lower={"t2m": 0.0, "r2m": 0.0}, + upper={"r2m": 100.0, "u100m": 100.0}, + ) + ), + ) + + model = GraphLAM( + args=model_args, + datastore=datastore, + config=config, + ) + + features = datastore.get_vars_names(category="state") + original_state = torch.zeros(1, 1, len(features)) + zero_delta = original_state.clone() + + # Get a state well within the bounds + original_state[:, :, model.clamp_lower_upper_idx] = ( + model.sigmoid_lower_lims + model.sigmoid_upper_lims + ) / 2 + original_state[:, :, model.clamp_lower_idx] = model.softplus_lower_lims + 10 + original_state[:, :, model.clamp_upper_idx] = model.softplus_upper_lims - 10 + + # Get a delta that tries to push the state out of bounds + delta = torch.ones_like(zero_delta) + delta[:, :, model.clamp_lower_upper_idx] = ( + 
model.sigmoid_upper_lims - model.sigmoid_lower_lims + ) / 3 + delta[:, :, model.clamp_lower_idx] = -5 + delta[:, :, model.clamp_upper_idx] = 5 + + # Check that a delta of 0 gives unchanged state + zero_prediction = model.clamp_prediction(zero_delta, original_state) + assert (abs(original_state - zero_prediction) < 1e-6).all().item() + + # Make predictions towards bounds for each feature + prediction = zero_prediction.clone() + n_loops = 100 + for i in range(n_loops): + prediction = model.clamp_prediction(delta, prediction) + + # check that unclamped states are as expected + # delta is 1, so they should be 1*n_loops + assert ( + ( + abs( + prediction[ + :, + :, + list( + set(range(len(features))) + - set(model.clamp_lower_upper_idx.tolist()) + - set(model.clamp_lower_idx.tolist()) + - set(model.clamp_upper_idx.tolist()) + ), + ] + - n_loops + ) + < 1e-6 + ) + .all() + .item() + ) + + # Check that clamped states are within bounds + # they should not be at the bounds but allow it due to numerical precision + assert ( + ( + model.sigmoid_lower_lims + <= prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) + assert ( + (model.softplus_lower_lims <= prediction[:, :, model.clamp_lower_idx]) + .all() + .item() + ) + assert ( + (prediction[:, :, model.clamp_upper_idx] <= model.softplus_upper_lims) + .all() + .item() + ) + + # Check that prediction is within bounds in original non-normalized space + unscaled_prediction = prediction * model.state_std + model.state_mean + features_idx = {f: i for i, f in enumerate(features)} + lower_lims = { + features_idx[f]: lim + for f, lim in config.training.output_clamping.lower.items() + } + upper_lims = { + features_idx[f]: lim + for f, lim in config.training.output_clamping.upper.items() + } + assert ( + ( + torch.tensor(list(lower_lims.values())) + <= unscaled_prediction[:, :, list(lower_lims.keys())] + ) + .all() + .item() + ) + assert ( + ( + unscaled_prediction[:, :, list(upper_lims.keys())] + <= torch.tensor(list(upper_lims.values())) + ) + .all() + .item() + ) + + # Check that a prediction from a state starting outside the bounds is also + # pushed within bounds. 3 delta should be enough to give an initial state + # out of bounds so 5 is well outside + invalid_state = original_state + 5 * delta + assert ( + not ( + model.sigmoid_lower_lims + <= invalid_state[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .any() + .item() + ) + assert ( + not ( + model.softplus_lower_lims + <= invalid_state[:, :, model.clamp_lower_idx] + ) + .any() + .item() + ) + assert ( + not ( + invalid_state[:, :, model.clamp_upper_idx] + <= model.softplus_upper_lims + ) + .any() + .item() + ) + invalid_prediction = model.clamp_prediction(zero_delta, invalid_state) + assert ( + ( + model.sigmoid_lower_lims + <= invalid_prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) + assert ( + ( + model.softplus_lower_lims + <= invalid_prediction[:, :, model.clamp_lower_idx] + ) + .all() + .item() + ) + assert ( + ( + invalid_prediction[:, :, model.clamp_upper_idx] + <= model.softplus_upper_lims + ) + .all() + .item() + ) + + # Above tests only check the upper sigmoid limit. 
+ # Repeat to check lower sigmoid limit + + # Make predictions towards bounds for each feature + prediction = zero_prediction.clone() + n_loops = 100 + for i in range(n_loops): + prediction = model.clamp_prediction(-delta, prediction) + + # Check that clamped states are within bounds + assert ( + ( + model.sigmoid_lower_lims + <= prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) + + # Check that prediction is within bounds in original non-normalized space + assert ( + ( + torch.tensor(list(lower_lims.values())) + <= unscaled_prediction[:, :, list(lower_lims.keys())] + ) + .all() + .item() + ) + assert ( + ( + unscaled_prediction[:, :, list(upper_lims.keys())] + <= torch.tensor(list(upper_lims.values())) + ) + .all() + .item() + ) + + # Check that a prediction from a state starting outside the bounds is also + # pushed within bounds. 3 delta should be enough to give an initial state + # out of bounds so 5 is well outside + invalid_state = original_state - 5 * delta + assert ( + not ( + model.sigmoid_lower_lims + <= invalid_state[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .any() + .item() + ) + invalid_prediction = model.clamp_prediction(zero_delta, invalid_state) + assert ( + ( + model.sigmoid_lower_lims + <= invalid_prediction[:, :, model.clamp_lower_upper_idx] + <= model.sigmoid_upper_lims + ) + .all() + .item() + ) From 3bc51ab8414eb7dcd18e0d98fd2faa05a69efa0f Mon Sep 17 00:00:00 2001 From: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com> Date: Mon, 16 Dec 2024 09:50:30 +0100 Subject: [PATCH 07/17] fix typo changed prepare_clamping_parames to prepare_clamping_params --- neural_lam/models/base_graph_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 212aefc7..bddd0ebc 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -80,9 +80,9 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): ) # No layer norm on this one # Compute indices and define clamping functions - self.prepare_clamping_parames(config, datastore) + self.prepare_clamping_params(config, datastore) - def prepare_clamping_parames( + def prepare_clamping_params( self, config: NeuralLAMConfig, datastore: BaseDatastore ): """ From 0681a3559dbe2c7978c7d30215a38a2027a340b2 Mon Sep 17 00:00:00 2001 From: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com> Date: Mon, 16 Dec 2024 13:57:18 +0100 Subject: [PATCH 08/17] Update README.md Added description of clamping feature in config.yaml --- README.md | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e21b7c24..51f3f375 100644 --- a/README.md +++ b/README.md @@ -148,12 +148,22 @@ training: values: u100m: 1.0 v100m: 1.0 + t2m: 1.0 + r2m: 1.0 + output_clamping: + lower: + t2m: 0.0 + r2m: 0 + upper: + r2m: 100.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data -store and the path to its config, and 2) the weighting of different features in -the loss function. If you don't define the state feature weighting it will default -to weighting all features equally. +For now the neural-lam config only defines few things: 1) the kind of data +store and the path to its config, 2) the weighting of different features in +the loss function, and 3) valid numerical range for output of each feature. 
+If you don't define the state feature weighting it will default to +weighting all features equally. The numerical range of all features default +to $]-\infty, \infty[$. (This example is taken from the `tests/datastore_examples/mdp` directory.) From 82844e3b8498130faa738bd0025e707df08cab35 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Tue, 17 Dec 2024 08:17:20 +0000 Subject: [PATCH 09/17] linting --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 51f3f375..f1bcd8d7 100644 --- a/README.md +++ b/README.md @@ -160,8 +160,8 @@ training: For now the neural-lam config only defines few things: 1) the kind of data store and the path to its config, 2) the weighting of different features in -the loss function, and 3) valid numerical range for output of each feature. -If you don't define the state feature weighting it will default to +the loss function, and 3) valid numerical range for output of each feature. +If you don't define the state feature weighting it will default to weighting all features equally. The numerical range of all features default to $]-\infty, \infty[$. From f53ae59b8198421dacca88030fc14f68b984ed41 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Fri, 10 Jan 2025 13:07:52 +0000 Subject: [PATCH 10/17] update docstring and name of clamping function --- neural_lam/models/base_graph_model.py | 11 +++++++++-- tests/test_clamping.py | 10 +++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index bddd0ebc..131a40d5 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -218,11 +218,18 @@ def inverse_sigmoid(x): + self.softplus_center ) - def clamp_prediction(self, state_delta, prev_state): + def get_clamped_new_state(self, state_delta, prev_state): """ Clamp prediction to valid range supplied in config Returns the clamped new state after adding delta to original state + Instead of the new state being computed as + $X_{t+1} = X_t + \\delta = X_t + model(\\{X_t,X_{t-1},...\\}, forcing)$ + The clamped values will be + $f(f^{-1}(X_t) + model(\\{X_t, X_{t-1},... 
\\}, forcing))$ + Which means the model will learn to output values in the range of the + inverse clamping function + state_delta: (B, num_grid_nodes, feature_dim) prev_state: (B, num_grid_nodes, feature_dim) """ @@ -354,6 +361,6 @@ def predict_step(self, prev_state, prev_prev_state, forcing): rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean # Clamp values to valid range (also add the delta to the previous state) - new_state = self.clamp_prediction(rescaled_delta_mean, prev_state) + new_state = self.get_clamped_new_state(rescaled_delta_mean, prev_state) return new_state, pred_std diff --git a/tests/test_clamping.py b/tests/test_clamping.py index 457be631..d197766e 100644 --- a/tests/test_clamping.py +++ b/tests/test_clamping.py @@ -82,14 +82,14 @@ class ModelArgs: delta[:, :, model.clamp_upper_idx] = 5 # Check that a delta of 0 gives unchanged state - zero_prediction = model.clamp_prediction(zero_delta, original_state) + zero_prediction = model.get_clamped_new_state(zero_delta, original_state) assert (abs(original_state - zero_prediction) < 1e-6).all().item() # Make predictions towards bounds for each feature prediction = zero_prediction.clone() n_loops = 100 for i in range(n_loops): - prediction = model.clamp_prediction(delta, prediction) + prediction = model.get_clamped_new_state(delta, prediction) # check that unclamped states are as expected # delta is 1, so they should be 1*n_loops @@ -193,7 +193,7 @@ class ModelArgs: .any() .item() ) - invalid_prediction = model.clamp_prediction(zero_delta, invalid_state) + invalid_prediction = model.get_clamped_new_state(zero_delta, invalid_state) assert ( ( model.sigmoid_lower_lims @@ -227,7 +227,7 @@ class ModelArgs: prediction = zero_prediction.clone() n_loops = 100 for i in range(n_loops): - prediction = model.clamp_prediction(-delta, prediction) + prediction = model.get_clamped_new_state(-delta, prediction) # Check that clamped states are within bounds assert ( @@ -271,7 +271,7 @@ class ModelArgs: .any() .item() ) - invalid_prediction = model.clamp_prediction(zero_delta, invalid_state) + invalid_prediction = model.get_clamped_new_state(zero_delta, invalid_state) assert ( ( model.sigmoid_lower_lims From 59793934a59a286b0d97df8468a815849b69cf13 Mon Sep 17 00:00:00 2001 From: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com> Date: Tue, 14 Jan 2025 15:42:24 +0100 Subject: [PATCH 11/17] Update README.md Co-authored-by: Joel Oskarsson --- README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f1bcd8d7..4960711c 100644 --- a/README.md +++ b/README.md @@ -158,12 +158,13 @@ training: r2m: 100.0 ``` -For now the neural-lam config only defines few things: 1) the kind of data -store and the path to its config, 2) the weighting of different features in -the loss function, and 3) valid numerical range for output of each feature. -If you don't define the state feature weighting it will default to -weighting all features equally. The numerical range of all features default -to $]-\infty, \infty[$. +For now the neural-lam config only defines few things: + +1. The kind of datastore and the path to its config +2. The weighting of different features in +the loss function. If you don't define the state feature weighting it will default to +weighting all features equally. +3. Valid numerical range for output of each feature.The numerical range of all features default to $]-\infty, \infty[$. (This example is taken from the `tests/datastore_examples/mdp` directory.) 
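The relation described in the `get_clamped_new_state` docstring above, $X_{t+1} = f(f^{-1}(X_t) + \delta)$, can be illustrated with a minimal standalone sketch of the two-sided (sigmoid) case. This is not part of the changeset; the limits, the sharpness/center values of 1 and 0, and the example states are only illustrative:

```python
import torch

# Illustrative normalized limits for a doubly-bounded feature (e.g. r2m in [0, 1])
lower_lim, upper_lim = torch.tensor(0.0), torch.tensor(1.0)
sigmoid_sharpness, sigmoid_center = 1.0, 0.0


def clamp_lower_upper(x):
    # f: maps any real value into the open interval ]lower_lim, upper_lim[
    return lower_lim + (upper_lim - lower_lim) * torch.sigmoid(
        sigmoid_sharpness * (x - sigmoid_center)
    )


def inverse_clamp_lower_upper(x):
    # f^-1: maps a value in ]lower_lim, upper_lim[ back to the unbounded
    # space, keeping the ratio away from 0 and 1 to avoid +/- inf
    ratio = torch.clamp(
        (x - lower_lim) / (upper_lim - lower_lim), 1e-6, 1 - 1e-6
    )
    return sigmoid_center + torch.log(ratio / (1 - ratio)) / sigmoid_sharpness


prev_state = torch.tensor(0.9)   # already close to the upper limit
state_delta = torch.tensor(5.0)  # model output tries to push far past it

# X_{t+1} = f(f^{-1}(X_t) + delta) stays strictly inside ]0, 1[
new_state = clamp_lower_upper(
    inverse_clamp_lower_upper(prev_state) + state_delta
)
print(new_state)  # ~0.999, never reaches 1.0
```

Because the delta is added in the unbounded space and then mapped back through the clamping function, a state near (or even outside) a limit can only approach that limit asymptotically, which is what `tests/test_clamping.py` verifies.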
From 18cb4721ce645415f224c28a75148d7a8bcae747 Mon Sep 17 00:00:00 2001 From: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com> Date: Tue, 14 Jan 2025 15:43:38 +0100 Subject: [PATCH 12/17] Update neural_lam/models/base_graph_model.py Co-authored-by: Joel Oskarsson --- neural_lam/models/base_graph_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 131a40d5..869fdcdd 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -194,6 +194,9 @@ def inverse_softplus(x, beta=1, threshold=20): return x def inverse_sigmoid(x): + # Sigmoid output takes values in [0,1], this makes sure input is just within this interval + # Note that this torch.clamp will make gradients 0, but this is not a problem + # as values of x that are this close to 0 or 1 have gradient 0 anyhow. x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6) return torch.log(x_clamped / (1 - x_clamped)) From afeee02daf6bbd7cf53ef516db6ce16db5d31e49 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Tue, 14 Jan 2025 15:16:38 +0000 Subject: [PATCH 13/17] review suggestions --- neural_lam/models/base_graph_model.py | 50 ++++++------------- neural_lam/utils.py | 33 ++++++++++++ .../mdp/danra_100m_winds/config.yaml | 2 +- 3 files changed, 50 insertions(+), 35 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 869fdcdd..f88e3dee 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -108,10 +108,10 @@ def prepare_clamping_params( ) # Constant parameters for clamping - self.sigmoid_sharpness = 1 - self.softplus_sharpness = 1 - self.sigmoid_center = 0 - self.softplus_center = 0 + sigmoid_sharpness = 1 + softplus_sharpness = 1 + sigmoid_center = 0 + softplus_center = 0 normalize_clamping_lim = ( lambda x, feature_idx: (x - self.state_mean[feature_idx]) @@ -167,58 +167,40 @@ def prepare_clamping_params( self.clamp_lower_upper = lambda x: ( self.sigmoid_lower_lims + (self.sigmoid_upper_lims - self.sigmoid_lower_lims) - * torch.sigmoid(self.sigmoid_sharpness * (x - self.sigmoid_center)) + * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center)) ) self.clamp_lower = lambda x: ( self.softplus_lower_lims + torch.nn.functional.softplus( - x - self.softplus_center, beta=self.softplus_sharpness + x - softplus_center, beta=softplus_sharpness ) ) self.clamp_upper = lambda x: ( self.softplus_upper_lims - torch.nn.functional.softplus( - self.softplus_center - x, beta=self.softplus_sharpness + softplus_center - x, beta=softplus_sharpness ) ) - # Define inverse clamping functions - def inverse_softplus(x, beta=1, threshold=20): - # If x*beta is above threshold, returns linear function - # for numerical stability - non_linear_part = ( - torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta - ) - x = torch.where(x * beta <= threshold, non_linear_part, x) - - return x - - def inverse_sigmoid(x): - # Sigmoid output takes values in [0,1], this makes sure input is just within this interval - # Note that this torch.clamp will make gradients 0, but this is not a problem - # as values of x that are this close to 0 or 1 have gradient 0 anyhow. 
- x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6) - return torch.log(x_clamped / (1 - x_clamped)) - self.inverse_clamp_lower_upper = lambda x: ( - self.sigmoid_center - + inverse_sigmoid( + sigmoid_center + + utils.inverse_sigmoid( (x - self.sigmoid_lower_lims) / (self.sigmoid_upper_lims - self.sigmoid_lower_lims) ) - / self.sigmoid_sharpness + / sigmoid_sharpness ) self.inverse_clamp_lower = lambda x: ( - inverse_softplus( - x - self.softplus_lower_lims, beta=self.softplus_sharpness + utils.inverse_softplus( + x - self.softplus_lower_lims, beta=softplus_sharpness ) - + self.softplus_center + + softplus_center ) self.inverse_clamp_upper = lambda x: ( - -inverse_softplus( - self.softplus_upper_lims - x, beta=self.softplus_sharpness + -utils.inverse_softplus( + self.softplus_upper_lims - x, beta=softplus_sharpness ) - + self.softplus_center + + softplus_center ) def get_clamped_new_state(self, state_delta, prev_state): diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 4a0752e4..f1861940 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -241,3 +241,36 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +def inverse_softplus(x, beta=1, threshold=20): + """ + Inverse of torch.nn.functional.softplus + + For x*beta above threshold, returns linear function for numerical + stability. + + Input is clamped to x > ln(1+1e-6)/beta which is approximately positive + values of x. + Note that this torch.clamp_min will make gradients 0, but this is not a + problem as values of x that are this close to 0 have gradients of 0 anyhow. + """ + non_linear_part = ( + torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta + ) + x = torch.where(x * beta <= threshold, non_linear_part, x) + + return x + + +def inverse_sigmoid(x): + """ + Inverse of torch.sigmoid + + Sigmoid output takes values in [0,1], this makes sure input is just within + this interval. + Note that this torch.clamp will make gradients 0, but this is not a problem + as values of x that are this close to 0 or 1 have gradients of 0 anyhow. + """ + x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6) + return torch.log(x_clamped / (1 - x_clamped)) diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml index d311c121..8b3362e0 100644 --- a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml +++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml @@ -14,5 +14,5 @@ training: t2m: 0.0 r2m: 0 upper: - r2m: 100.0 + r2m: 1.0 u100m: 100.0 From bb40e8226a857bf81a76de3d7d88eeb0848c54cd Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Tue, 14 Jan 2025 15:21:30 +0000 Subject: [PATCH 14/17] linting --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4960711c..10b0a375 100644 --- a/README.md +++ b/README.md @@ -158,7 +158,7 @@ training: r2m: 100.0 ``` -For now the neural-lam config only defines few things: +For now the neural-lam config only defines few things: 1. The kind of datastore and the path to its config 2. The weighting of different features in From f10886f78f66103cd920d58dffc09575f3ec37f5 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Tue, 14 Jan 2025 15:48:02 +0000 Subject: [PATCH 15/17] set clamp lims as buffers. 
Updated clamping test with correct r2m limit --- neural_lam/models/base_graph_model.py | 30 ++++++++++++++++++++------- tests/test_clamping.py | 2 +- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index f88e3dee..0e61f15f 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -154,14 +154,28 @@ def prepare_clamping_params( normalize_clamping_lim(upper_lims[feature], feature_idx) ) - self.sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims) - self.sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims) - self.softplus_lower_lims = torch.tensor(softplus_lower_lims) - self.softplus_upper_lims = torch.tensor(softplus_upper_lims) - - self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx) - self.clamp_lower_idx = torch.tensor(softplus_lower_idx) - self.clamp_upper_idx = torch.tensor(softplus_upper_idx) + self.register_buffer( + "sigmoid_lower_lims", torch.tensor(sigmoid_lower_lims) + ) + self.register_buffer( + "sigmoid_upper_lims", torch.tensor(sigmoid_upper_lims) + ) + self.register_buffer( + "softplus_lower_lims", torch.tensor(softplus_lower_lims) + ) + self.register_buffer( + "softplus_upper_lims", torch.tensor(softplus_upper_lims) + ) + + self.register_buffer( + "clamp_lower_upper_idx", torch.tensor(sigmoid_lower_upper_idx) + ) + self.register_buffer( + "clamp_lower_idx", torch.tensor(softplus_lower_idx) + ) + self.register_buffer( + "clamp_upper_idx", torch.tensor(softplus_upper_idx) + ) # Define clamping functions self.clamp_lower_upper = lambda x: ( diff --git a/tests/test_clamping.py b/tests/test_clamping.py index d197766e..f3f9365d 100644 --- a/tests/test_clamping.py +++ b/tests/test_clamping.py @@ -51,7 +51,7 @@ class ModelArgs: training=nlconfig.TrainingConfig( output_clamping=nlconfig.OutputClamping( lower={"t2m": 0.0, "r2m": 0.0}, - upper={"r2m": 100.0, "u100m": 100.0}, + upper={"r2m": 1.0, "u100m": 100.0}, ) ), ) From 964a2361907276c4f5cdd04152145d12b2ef026d Mon Sep 17 00:00:00 2001 From: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com> Date: Fri, 17 Jan 2025 08:42:49 +0100 Subject: [PATCH 16/17] Update README.md Co-authored-by: Joel Oskarsson --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 10b0a375..d43b6a4c 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ training: t2m: 0.0 r2m: 0 upper: - r2m: 100.0 + r2m: 1.0 ``` For now the neural-lam config only defines few things: From 9e7a4f0f23f8bf69cde6ae9951168b002125fbe6 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Fri, 17 Jan 2025 07:46:06 +0000 Subject: [PATCH 17/17] update changelong --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32961b16..3b33053c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#66](https://github.com/mllam/neural-lam/pull/66) @leifdenby @sadamov +- Add option to clamp output prediction using limits specified in config file [\#92](https://github.com/mllam/neural-lam/pull/92) @SimonKamuk + ### Fixed - Fix wandb environment variable disabling wandb during tests. Now correctly uses WANDB_MODE=disabled. [\#94](https://github.com/mllam/neural-lam/pull/94) @joeloskarsson
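For features with only a lower or an upper limit, the same round trip uses softplus together with the `inverse_softplus` helper added to `neural_lam/utils.py` in [PATCH 13/17]. Below is a minimal standalone sketch of the lower-bounded case; the bound, sharpness, threshold and example values are illustrative and not taken from the changeset:

```python
import torch
import torch.nn.functional as F

# Illustrative normalized lower bound for a feature clamped only from below
lower_lim = torch.tensor(0.0)
softplus_sharpness, softplus_center = 1.0, 0.0


def clamp_lower(x):
    # maps any real value into ]lower_lim, +inf[
    return lower_lim + F.softplus(
        x - softplus_center, beta=softplus_sharpness
    )


def inverse_clamp_lower(x, threshold=20.0):
    # inverse softplus: log(expm1(beta * y)) / beta, falling back to the
    # identity above the threshold for numerical stability (mirrors
    # utils.inverse_softplus)
    y = (x - lower_lim) * softplus_sharpness
    inv = torch.log(torch.clamp_min(torch.expm1(y), 1e-6)) / softplus_sharpness
    return torch.where(y <= threshold, inv, x - lower_lim) + softplus_center


prev_state = torch.tensor(0.3)     # valid state just above the bound
state_delta = torch.tensor(-10.0)  # model output tries to go far below it

new_state = clamp_lower(inverse_clamp_lower(prev_state) + state_delta)
print(new_state)  # small positive value, never below 0
```

As in the helper, inputs with `x * beta` above the threshold fall back to a linear function, so the inverse stays finite and well-behaved for large values.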