
Add Functionality to Apply Constraints to Predictions #92

Merged
merged 23 commits into mllam:main from feat/prediction_constraints on Feb 7, 2025

Conversation

SimonKamuk
Contributor

@SimonKamuk SimonKamuk commented Nov 29, 2024

Describe your changes

This change implements a method for constraining model output to a specified valid range. This is useful to ensure reliable model output for variables that cannot physically fall outside such a range, such as absolute temperature, which must be positive, or relative humidity, which must be between 0 and 100%.

This is implemented by specifying valid ranges for each parameter in config.yaml, with every variable defaulting to having no limit. A scaled sigmoid function is then applied to the prediction for variables that have both an upper and a lower limit, and a scaled softplus is used for variables that only need to stay above or below a single threshold.
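For illustration, a minimal sketch of what such clamping functions could look like in PyTorch (the function names and the config example below are hypothetical, not the exact code added by this PR):

```python
import torch
import torch.nn.functional as F

def clamp_two_sided(x, lower, upper):
    # scaled sigmoid: maps any real value into (lower, upper)
    return (upper - lower) * torch.sigmoid(x) + lower

def clamp_from_below(x, lower):
    # shifted softplus: maps any real value into (lower, inf)
    return F.softplus(x) + lower

def clamp_from_above(x, upper):
    # mirrored softplus: maps any real value into (-inf, upper)
    return upper - F.softplus(-x)

# hypothetical config.yaml-style ranges, e.g.
# relative_humidity: [0, 1]   -> two-sided clamp
# absolute_temperature: [0, ] -> lower-bound clamp only
```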

Issue Link

closes #19

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • author has added an entry to the changelog (and designated the change as added, changed or fixed)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@joeloskarsson
Collaborator

@SimonKamuk did you figure out a solution to #19 (comment) ? Sorry to comment already now, I know this is work in progress, I'm just curious about this :)

Thinking about it a bit more, I realized that one solution would be to just always apply the skip-connection before any activation function. So that the skip-connection is for the non-clamped values. E.g. since both sigmoid and softplus are invertible you could do something like $f(f^{-1}(X^t) + \text{model}())$ (although there are probably better ways to implement it).

@SimonKamuk
Contributor Author

@SimonKamuk did you figure out a solution to #19 (comment) ? Sorry to comment already now, I know this is work in progress, I'm just curious about this :)

Thinking about it a bit more, I realized that one solution would be to just always apply the skip-connection before any activation function. So that the skip-connection is for the non-clamped values. E.g. since both sigmoid and softplus are invertible you could do something like $f(f^{-1}(X^t) + \text{model}())$ (although there are probably better ways to implement it).

That's quite a neat way to do it!

I initially applied the clamping function to the new state, $f(X^t + \text{model}())$, but then realized this would mess with the residual connection, so what I implemented is basically this:

  • First I scale the clamping values according to the state normalization: clamping between $[a,b]$ becomes $[(a-\text{mean})/\text{std}, (b-\text{mean})/\text{std}]$
  • Then when I apply the clamping activation functions during training, I subtract the previous state from the limits, so it becomes $[(a-\text{mean})/\text{std} - X^t, (b-\text{mean})/\text{std} - X^t]$. But the function is only applied to the delta output by the model: $X^{t+1} = X^t + f(\text{model}())$ (roughly as in the sketch below).
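In code form, the delta-clamping described above could look roughly like this (a simplified sketch with made-up names, ignoring per-variable handling):

```python
import torch

def clamp_delta_update(model_delta, prev_state, a, b, mean, std):
    # rescale the physical limits [a, b] into normalized (standardized) space
    lower = (a - mean) / std
    upper = (b - mean) / std
    # shift the limits by the previous state so the clamping acts on the delta only
    lo = lower - prev_state
    hi = upper - prev_state
    # scaled sigmoid keeps the delta inside [lo, hi]
    clamped_delta = (hi - lo) * torch.sigmoid(model_delta) + lo
    # X_{t+1} = X_t + f(model())
    return prev_state + clamped_delta
```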

I wonder if there is an argument for using this method compared to your suggestion?

@joeloskarsson
Collaborator

This is quite interesting and I'm trying to get some better understanding of what might be a good approach. I made this desmos interactive plot to try to wrap my head around it: https://www.desmos.com/calculator/tnrd6igkqb

Your method definitely also works correctly. I realized that this (clamping the delta to new bounds) is equivalent to not having any skip-connection to the previous state. Below I'm ignoring the mean and std rescaling for simplicity, and assume we want to constrain the variable to $[l,u]$. For a clamping function $c$, applied to some model output $x$

$$ c(x,l,u) = (u-l)\sigma(x) + l $$

where $\sigma$ is sigmoid, the model output is

$$ X^t + \delta = X^t + c(x, l - X^t, u - X^t) = X^t + (u - X^t - (l - X^t))\sigma(x) + l - X^t = (u-l)\sigma(x) + l = c(x,l,u) $$

So this practically removes the dependence on the previous state.
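A quick numeric check of this identity (a throwaway sketch using the definition of $c$ above):

```python
import torch

def c(x, l, u):
    return (u - l) * torch.sigmoid(x) + l

x = torch.randn(5)           # raw model output
X_t = torch.tensor(0.3)      # previous state
l, u = 0.0, 1.0              # bounds

lhs = X_t + c(x, l - X_t, u - X_t)  # skip connection + delta clamped to shifted bounds
rhs = c(x, l, u)                    # direct clamping, no dependence on X_t
assert torch.allclose(lhs, rhs)
```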

The difference to my "invert-previous-clamping" approach is that it would equate to skip connections on the logits, before any activation function (the clamping here). So that approach does maintain some dependence. I'm not sure if this is important. A simple way to implement that approach would be to do the clamping in ARModel.unroll_prediction rather than in predict_step. Then the whole AR-unrolling happens on logits, and clamping is applied only when the forecast is returned. That should work, since the loss is not computed from any direct call to predict_step, but it should maybe be double-checked.

This really relates to whether one wants to use the skip-connections over states or not. I think it would eventually be nice to have this as an option. Maybe the choice between these two clamping strategies should then correspond to that option?

@SimonKamuk
Contributor Author

Oh wow, that's a good catch. I agree that we want the option to keep the skip connection, so my method is not the way to go - even if it were applied only after unrolling, because then we would still be removing the final skip-connection at every AR step (although the first ones would indeed be preserved). I'll have a go at implementing your inverse method.

@SimonKamuk
Contributor Author

I implemented your suggestion, but I added the constraint that the input (previous state) to the inverse sigmoid and softplus is hard-clamped, to keep the inverse functions from returning inf - otherwise, if say relative humidity were clamped to [0,1] and the previous state was already 1, the model could never output anything other than 1:

$\sigma(\sigma^{−1}(1)+model()) = \sigma(\infty+model()) = \sigma(\infty) = 1$

But this should not be an issue, as the clamping is only applied to the previous state, not the model output itself, so the gradients can still be computed.
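For the two-sided case, a minimal sketch of this inverse-clamping update with the hard clamp on the previous state (the names, epsilon, and normalization handling here are assumptions, not the PR's exact code):

```python
import torch

def inverse_sigmoid(y, eps=1e-6):
    # hard-clamp the previous state away from 0/1 so the logit stays finite
    y = torch.clamp(y, eps, 1.0 - eps)
    return torch.log(y) - torch.log1p(-y)

def inverse_clamped_update(prev_state, model_delta, lower, upper):
    # map the previous state into (0, 1), invert the clamping, add the model
    # output in the pre-activation (logit) space, then clamp back:
    # f(f^{-1}(X^t) + model())
    prev_unit = (prev_state - lower) / (upper - lower)
    logits = inverse_sigmoid(prev_unit) + model_delta
    return (upper - lower) * torch.sigmoid(logits) + lower
```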

@joeloskarsson
Collaborator

Hmm, yes, that's an important consideration. Good that you thought about this. I'm guessing the situation could occur where a variable is constrained to >= 0 and an initial state exists where it is exactly 0.

Note that gradients also flow through the previous state (we don't detach it from the computational graph), not just the model output, when we unroll during training. So the clamping does still impact gradients. However, I don't think this should be a problem in practice, and this solution should work fine. In the case where the previous state comes from a model prediction during rollout, it should not be possible for it to hit exactly 0/1, so the clamping would not have an effect anyway.

@SimonKamuk SimonKamuk marked this pull request as ready for review December 13, 2024 10:27
@SimonKamuk
Contributor Author

I still don't understand why this last test is failing - could it be a resource issue? If anyone knows what is going on, I'm all ears 😄 but as far as my changes are concerned, I think this is ready for review.

@joeloskarsson
Collaborator

The failing test is probably not directly related to the code, but to resources. I see Error: Process completed with exit code 247., but I'm not sure what that means (or what exactly I should look this exit code up for). It seems to happen when testing the training loop, so it might be related to memory or other resources.

I've added giving this a proper review to my TODO list. A couple of high-level considerations in the meantime:

  1. Does most of this functionality (in particular the clamping prep/application methods) belong in BaseGraphModel, or should it sit already in ARModel? I am thinking that any model (even hypothetical non-graph models) would need these methods.
  2. When I described my idea for this I thought of it as inverting the activation-function clamping from the previous time step. This is how it is now implemented. It does however mean that we have to clamp and unclamp these states all the time. The inverse clamp is a bit of unnecessary compute really. Another way to do this would be (from my comment above)

A simple way to implement that approach would be to do the clamping in ARModel.unroll_prediction rather than in predict_step. Then the whole AR-unrolling happens on logits, and clamping is applied only when the forecast is returned. That should work, since the loss is not computed from any direct call to predict_step, but it should maybe be double-checked.

What are your thoughts on this? Are there good reasons to do it the "inversion" way? The extra compute is quite small, so maybe not really an issue, but doing the inverse-clamping is a bit more complicated and less transparent in showing that this is applying skip connections on pre-activation representations.

changed prepare_clamping_parames to prepare_clamping_params
@SimonKamuk
Contributor Author

SimonKamuk commented Dec 16, 2024

The failing test is probably not directly related to the code, but to resources. I see Error: Process completed with exit code 247., but I'm not sure what that means (or what exactly I should look this exit code up for). It seems to happen when testing the training loop, so it might be related to memory or other resources.

I've added giving this a proper review to my TODO list. A couple of high-level considerations in the meantime:

  1. Does most of this functionality (in particular the clamping prep/application methods) belong in BaseGraphModel, or should it sit already in ARModel? I am thinking that any model (even hypothetical non-graph models) would need these methods.
  2. When I described my idea for this I thought of it as inverting the activation-function clamping from the previous time step. This is how it is now implemented. It does however mean that we have to clamp and unclamp these states all the time. The inverse clamp is a bit of unnecessary compute really. Another way to do this would be (from my comment above)

A simple way to implement that approach would be to do the clamping in ARModel.unroll_prediction rather than in predict_step. Then the whole AR-unrolling happens on logits, and clamping is applied only when the forecast is returned. That should work, since the loss is not computed from any direct call to predict_step, but it should maybe be double-checked.

What are your thoughts on this? Are there good reasons to do it the "inversion" way? The extra compute is quite small, so maybe not really an issue, but doing the inverse-clamping is a bit more complicated and less transparent in showing that this is applying skip connections on pre-activation representations.

  1. I've added my changes to BaseGraphModel because the predict_step method is not implemented in ARModel. I did consider putting it in ARModel, but I figured that if someone made another model with a different predict_step (i.e. one without the skip connection), then clamp_prediction would need to change. I could move prepare_clamping_params and clamp_prediction to ARModel and then add a comment noting that clamp_prediction assumes a model with a skip connection, if you prefer? Or maybe just move prepare_clamping_params to ARModel?

  2. My gut feeling was that the extra compute would be negligible, but maybe I should actually test what the impact is. Regarding whether to put it in predict_step or unroll_prediction, I think I just felt it was clearer for the model to predict physically consistent values at every time step. As you say, it should not matter much for the loss, as only the output of unroll_prediction is fed to the loss. But if the clamping is applied at each predict_step, then the model would only ever receive valid inputs and return valid outputs, and then it wouldn't need to learn how to interpret what a humidity above 100% means, which could possibly help with model accuracy.

SimonKamuk and others added 2 commits December 16, 2024 13:57
Added description of clamping feature in config.yaml
@joeloskarsson
Collaborator

Thanks for the clarifications above @SimonKamuk ! You make some good points that I did not think about. In particular, as you write, we need to consider what actually goes into the model at each time step, which depends on when the clamping is applied. I need to think that over + then give this a full review. That will have to be in 2025 though, so just letting you know to not expect my review on this until after new years 😃

@ole-dmi

ole-dmi commented Jan 8, 2025

I think the clamping solution and implementation look good and that we are almost ready to close this PR.

@SimonKamuk and I just had a conversation where he explained the $f(f^{-1}(X^t) + \text{model}())$ expression. In particular, I was puzzled by the $f^{-1}(X^t)$ term. Now I see that the output of $\text{model}(X^t)$ is also in the inverse space, and it makes sense to me.

We also talked about how this solution resembles/relates to methods in constrained optimization, like penalty methods and Lagrange multipliers, e.g. adding a penalty term to the loss function: $\text{loss}(x) + \text{penalty}(x)$.
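For illustration only, a soft-constraint penalty of that kind could look like the sketch below (this is the alternative being compared against, not what this PR implements; the names are made up):

```python
import torch

def bound_violation_penalty(pred, lower, upper, weight=1.0):
    # penalize predictions outside [lower, upper]; zero whenever the constraint holds
    below = torch.clamp(lower - pred, min=0.0)
    above = torch.clamp(pred - upper, min=0.0)
    return weight * (below.square() + above.square()).mean()

# total_loss = base_loss + bound_violation_penalty(prediction, 0.0, 1.0)
```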

In preparation for the review / conversation with Simon I looked at:
An interpretive constrained linear model for ResNet and MgNet
How to Encode Constraints to the Output of Neural Networks

It will be interesting to see how this solution performs in a real example, which we will evaluate once the model is up and running on Gefion.

The failing test seems to be due to a memory problem. Can we reduce the memory usage of the test?

I think we need to document the clamping function $f(f^{-1}(X^t) + \text{model}())$ somewhere appropriate, maybe in README.md or a docstring in the code.

@SimonKamuk
Contributor Author

I think the clamping solution and implementation look good and that we are almost ready to close this PR.

@SimonKamuk and I just had a conversation where he explained the $f(f^{-1}(X^t) + \text{model}())$ expression. In particular, I was puzzled by the $f^{-1}(X^t)$ term. Now I see that the output of $\text{model}(X^t)$ is also in the inverse space, and it makes sense to me.

We also talked about how this solution resembles/relates to methods in constrained optimization, like penalty methods and Lagrange multipliers, e.g. adding a penalty term to the loss function: $\text{loss}(x) + \text{penalty}(x)$.

In preparation for the review / conversation with Simon I looked at:
An interpretive constrained linear model for ResNet and MgNet
How to Encode Constraints to the Output of Neural Networks

It will be interesting to see how this solution performs in a real example, which we will evaluate once the model is up and running on Gefion.

The failing test seems to be due to a memory problem. Can we reduce the memory usage of the test?

I think we need to document the clamping function $f(f^{-1}(X^t) + \text{model}())$ somewhere appropriate, maybe in README.md or a docstring in the code.

I added the expression in the docstring - I am not too sure about the notation. Do we have any convention for mathematical expressions in the code?

@joeloskarsson
Collaborator

Not really, my view is:

  1. If we actually need math notation in the code, it's best to explain all non-trivial notation in the same comment where it is used.
  2. It could be good to keep the notation somewhat close to existing publications (in the README), just to make it easier for people reading both.

@joeloskarsson joeloskarsson added this to the v0.4.0 milestone Jan 12, 2025
Collaborator

@joeloskarsson joeloskarsson left a comment

Went over this more carefully now and this is looking really good! 😄 I just left some small comments, mostly about documentation. The added tests are great as well.

SimonKamuk and others added 5 commits January 14, 2025 15:42
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
@joeloskarsson
Collaborator

joeloskarsson commented Jan 16, 2025

@SimonKamuk Is this ready for me to look over again after the changes? (see also #92 (comment)). You can also go ahead and add an entry to the changelog.

@SimonKamuk
Contributor Author

@SimonKamuk Is this ready for me to look over again after the changes? (see also #92 (comment)). You can also go ahead and add an entry to the changelog.

@joeloskarsson yes, I think it is ready now 😄

Collaborator

@joeloskarsson joeloskarsson left a comment

I think this is perfect now. Wonderful job on this 🥳 Ready to merge from my side, however:

  1. As this should go into v0.4.0, we should wait until the v0.3.0 release is done (prepare v0.3.0 release #98) before merging. This might mean that you have to do a quick update to the changelog file to avoid conflicts (I did not think of this when reminding you about the changelog above).
  2. There looks to be some test issue that might require some attention.

@SimonKamuk
Contributor Author

All tests pass now, should I just go ahead and merge @joeloskarsson @ole-dmi ?

@joeloskarsson
Collaborator

Go ahead!

@SimonKamuk SimonKamuk merged commit f342487 into mllam:main Feb 7, 2025
8 checks passed
@SimonKamuk SimonKamuk deleted the feat/prediction_constraints branch February 7, 2025 13:41