Add Functionality to Apply Constraints to Predictions #92

Merged · 23 commits · Feb 7, 2025
18 changes: 14 additions & 4 deletions README.md
@@ -148,12 +148,22 @@ training:
values:
u100m: 1.0
v100m: 1.0
t2m: 1.0
r2m: 1.0
output_clamping:
lower:
t2m: 0.0
r2m: 0.0
upper:
r2m: 100.0
```

For now the neural-lam config only defines two things: 1) the kind of data
store and the path to its config, and 2) the weighting of different features in
the loss function. If you don't define the state feature weighting it will default
to weighting all features equally.
For now the neural-lam config only defines a few things: 1) the kind of data
store and the path to its config, 2) the weighting of different features in
the loss function, and 3) the valid numerical range for the output of each
feature. If you don't define the state feature weighting it will default to
weighting all features equally. The numerical range of all features defaults
to $]-\infty, \infty[$.

(This example is taken from the `tests/datastore_examples/mdp` directory.)

21 changes: 21 additions & 0 deletions neural_lam/config.py
@@ -68,6 +68,23 @@ class UniformFeatureWeighting:
pass


@dataclasses.dataclass
class OutputClamping:
"""
Configuration for clamping the output of the model.

Attributes
----------
lower : Dict[str, float]
The minimum value to clamp each output feature to.
upper : Dict[str, float]
The maximum value to clamp each output feature to.
"""

lower: Dict[str, float] = dataclasses.field(default_factory=dict)
upper: Dict[str, float] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class TrainingConfig:
"""
@@ -86,6 +103,10 @@ class TrainingConfig:
ManualStateFeatureWeighting, UniformFeatureWeighting
] = dataclasses.field(default_factory=UniformFeatureWeighting)

output_clamping: OutputClamping = dataclasses.field(
default_factory=OutputClamping
)


@dataclasses.dataclass
class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
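The `OutputClamping` limits are plain dictionaries, so the mapping from the YAML example in the README to the dataclasses above can be sketched directly. A minimal illustration (not part of this PR), assuming both classes are imported from `neural_lam.config`:

```python
from neural_lam.config import OutputClamping, TrainingConfig

# Lower/upper limits per state feature, mirroring the README example.
clamping = OutputClamping(
    lower={"t2m": 0.0, "r2m": 0.0},
    upper={"r2m": 100.0},
)
training = TrainingConfig(output_clamping=clamping)

# Features absent from both dicts are left unclamped, i.e. ]-inf, inf[.
assert "u100m" not in training.output_clamping.lower
assert "u100m" not in training.output_clamping.upper
```

The same structure is produced when the YAML config is loaded through `NeuralLAMConfig`; constructing it by hand here is only meant to show which keys end up where.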
193 changes: 191 additions & 2 deletions neural_lam/models/base_graph_model.py
@@ -79,6 +79,193 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
layer_norm=False,
) # No layer norm on this one

# Compute indices and define clamping functions
self.prepare_clamping_params(config, datastore)

def prepare_clamping_params(
self, config: NeuralLAMConfig, datastore: BaseDatastore
):
"""
Prepare parameters for clamping predicted values to valid range
"""

# Read configs
state_feature_names = datastore.get_vars_names(category="state")
lower_lims = config.training.output_clamping.lower
upper_lims = config.training.output_clamping.upper

# Check that limits in config are for valid features
unknown_features_lower = set(lower_lims.keys()) - set(
state_feature_names
)
unknown_features_upper = set(upper_lims.keys()) - set(
state_feature_names
)
if unknown_features_lower or unknown_features_upper:
raise ValueError(
"State feature limits were provided for unknown features: "
f"{unknown_features_lower.union(unknown_features_upper)}"
)

# Constant parameters for clamping
self.sigmoid_sharpness = 1
self.softplus_sharpness = 1
self.sigmoid_center = 0
self.softplus_center = 0

normalize_clamping_lim = (
lambda x, feature_idx: (x - self.state_mean[feature_idx])
/ self.state_std[feature_idx]
)

# Check which clamping functions to use for each feature
sigmoid_lower_upper_idx = []
sigmoid_lower_lims = []
sigmoid_upper_lims = []

softplus_lower_idx = []
softplus_lower_lims = []

softplus_upper_idx = []
softplus_upper_lims = []

for feature_idx, feature in enumerate(state_feature_names):
if feature in lower_lims and feature in upper_lims:
assert (
lower_lims[feature] < upper_lims[feature]
), (
f'Invalid clamping limits for feature "{feature}": '
f"lower ({lower_lims[feature]}) must be smaller than "
f"upper ({upper_lims[feature]})"
)
sigmoid_lower_upper_idx.append(feature_idx)
sigmoid_lower_lims.append(
normalize_clamping_lim(lower_lims[feature], feature_idx)
)
sigmoid_upper_lims.append(
normalize_clamping_lim(upper_lims[feature], feature_idx)
)
elif feature in lower_lims and feature not in upper_lims:
softplus_lower_idx.append(feature_idx)
softplus_lower_lims.append(
normalize_clamping_lim(lower_lims[feature], feature_idx)
)
elif feature not in lower_lims and feature in upper_lims:
softplus_upper_idx.append(feature_idx)
softplus_upper_lims.append(
normalize_clamping_lim(upper_lims[feature], feature_idx)
)

self.sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims)
self.sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims)
self.softplus_lower_lims = torch.tensor(softplus_lower_lims)
self.softplus_upper_lims = torch.tensor(softplus_upper_lims)

self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx)
self.clamp_lower_idx = torch.tensor(softplus_lower_idx)
self.clamp_upper_idx = torch.tensor(softplus_upper_idx)

# Define clamping functions
self.clamp_lower_upper = lambda x: (
self.sigmoid_lower_lims
+ (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
* torch.sigmoid(self.sigmoid_sharpness * (x - self.sigmoid_center))
)
self.clamp_lower = lambda x: (
self.softplus_lower_lims
+ torch.nn.functional.softplus(
x - self.softplus_center, beta=self.softplus_sharpness
)
)
self.clamp_upper = lambda x: (
self.softplus_upper_lims
- torch.nn.functional.softplus(
self.softplus_center - x, beta=self.softplus_sharpness
)
)

# Define inverse clamping functions
def inverse_softplus(x, beta=1, threshold=20):
# If x*beta is above threshold, returns linear function
# for numerical stability
non_linear_part = (
torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
)
x = torch.where(x * beta <= threshold, non_linear_part, x)

return x

def inverse_sigmoid(x):
x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6)
return torch.log(x_clamped / (1 - x_clamped))

self.inverse_clamp_lower_upper = lambda x: (
self.sigmoid_center
+ inverse_sigmoid(
(x - self.sigmoid_lower_lims)
/ (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
)
/ self.sigmoid_sharpness
)
self.inverse_clamp_lower = lambda x: (
inverse_softplus(
x - self.softplus_lower_lims, beta=self.softplus_sharpness
)
+ self.softplus_center
)
self.inverse_clamp_upper = lambda x: (
-inverse_softplus(
self.softplus_upper_lims - x, beta=self.softplus_sharpness
)
+ self.softplus_center
)

def get_clamped_new_state(self, state_delta, prev_state):
"""
Clamp the prediction to the valid range supplied in the config.
Returns the clamped new state after adding the delta to the previous state.

Instead of the new state being computed as
$X_{t+1} = X_t + \\delta = X_t + model(\\{X_t, X_{t-1}, ...\\}, forcing)$,
the clamped new state is
$X_{t+1} = f(f^{-1}(X_t) + model(\\{X_t, X_{t-1}, ...\\}, forcing))$,
which means the model will learn to output values in the range of the
inverse clamping function.

state_delta: (B, num_grid_nodes, feature_dim)
prev_state: (B, num_grid_nodes, feature_dim)
"""

# Assign new state, but overwrite clamped values of each type later
new_state = prev_state + state_delta

# Sigmoid/logistic clamps between ]a,b[
if self.clamp_lower_upper_idx.numel() > 0:
idx = self.clamp_lower_upper_idx

new_state[:, :, idx] = self.clamp_lower_upper(
self.inverse_clamp_lower_upper(prev_state[:, :, idx])
+ state_delta[:, :, idx]
)

# Softplus clamps between ]a,infty[
if self.clamp_lower_idx.numel() > 0:
idx = self.clamp_lower_idx

new_state[:, :, idx] = self.clamp_lower(
self.inverse_clamp_lower(prev_state[:, :, idx])
+ state_delta[:, :, idx]
)

# Softplus clamps between ]-infty,b[
if self.clamp_upper_idx.numel() > 0:
idx = self.clamp_upper_idx

new_state[:, :, idx] = self.clamp_upper(
self.inverse_clamp_upper(prev_state[:, :, idx])
+ state_delta[:, :, idx]
)

return new_state

def get_num_mesh(self):
"""
Compute number of mesh nodes from loaded features,
@@ -173,5 +360,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
# Rescale with one-step difference statistics
rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean

# Residual connection for full state
return prev_state + rescaled_delta_mean, pred_std
# Clamp values to valid range (also add the delta to the previous state)
new_state = self.get_clamped_new_state(rescaled_delta_mean, prev_state)

return new_state, pred_std
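The inverse-then-forward scheme described in `get_clamped_new_state` can be illustrated with a standalone sketch (not part of the PR). It assumes a single feature bounded to $]0, 100[$ (as for `r2m` above), sharpness 1 and center 0 as set in `prepare_clamping_params`, and skips the normalization of the limits with the state statistics that the model applies:

```python
import torch

lower, upper = torch.tensor(0.0), torch.tensor(100.0)

def clamp(x):
    # f: squash an unconstrained value into ]lower, upper[ with a sigmoid
    return lower + (upper - lower) * torch.sigmoid(x)

def inverse_clamp(x):
    # f^-1: map a value in ]lower, upper[ back to the unconstrained space
    frac = torch.clamp((x - lower) / (upper - lower), 1e-6, 1 - 1e-6)
    return torch.log(frac / (1 - frac))

prev_state = torch.tensor([95.0])   # already close to the upper bound
state_delta = torch.tensor([10.0])  # a plain residual update would give 105

new_state = clamp(inverse_clamp(prev_state) + state_delta)
print(new_state)  # ~99.9998, i.e. strictly inside ]0, 100[
```

Features with only a lower or only an upper bound follow the same pattern but use a softplus instead of a sigmoid, so the new state is bounded on one side and unbounded on the other.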
9 changes: 9 additions & 0 deletions tests/datastore_examples/mdp/danra_100m_winds/config.yaml
@@ -7,3 +7,12 @@ training:
weights:
u100m: 1.0
v100m: 1.0
t2m: 1.0
r2m: 1.0
output_clamping:
lower:
t2m: 0.0
r2m: 0.0
upper:
r2m: 100.0
u100m: 100.0
@@ -55,7 +55,7 @@ inputs:
dims: [x, y]
target_output_variable: state

danra_surface:
danra_surface_forcing:
path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr
dims: [time, x, y]
variables:
@@ -73,6 +73,24 @@
name_format: "{var_name}"
target_output_variable: forcing

danra_surface:
path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr
dims: [time, x, y]
variables:
- r2m
- t2m
dim_mapping:
time:
method: rename
dim: time
grid_index:
method: stack
dims: [x, y]
state_feature:
method: stack_variables_by_var_name
name_format: "{var_name}"
target_output_variable: state

danra_lsm:
path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr
dims: [x, y]