[Fix] schedulefree #351
Conversation
first of all, thanks for your contributions! i commented on your reviews, and please feel free to leave comments or ask questions :)
I haven't answered everything yet, but I'll review the rest later today.
yes, as you said, bias_correction2 is applied before ams_bound. in my implementation, i thought it brought slightly more efficient computation. on second thought, though, there seems to be no advantage in doing that, so i'm gonna refactor it later. thanks for bringing this up :)
@hatonosuke I checked all of your work and reviews, and seems like there's kinda a lot to change and refactor. So, I'm thinking of closing this PR and opening another branch, then working there based on this PR. (of course, the credit all goes to you) Here are things to work on:
thanks for your contributions! and if you have more questions or reviews, please feel free to leave them here :)
I got it.
Problem (Why?)
The schedule-free optimizer in pytorch-optimizer differs from the original implementations shown below.
https://github.com/facebookresearch/schedule_free
Below are the differences and questions.
Could you please review this?
weight decay
Since weight decay needs to be applied appropriately to both z and y, it must be handled differently from standard Adam. I implemented schedule-free-specific weight-decay handling.
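The point above can be illustrated with a minimal schedule-free SGD step (my own sketch, not the pytorch-optimizer code; function and variable names are mine). Decoupled weight decay is evaluated at y, the point where the gradient is taken, so it flows into the z update and, through the running average, into x as well:

```python
# Minimal sketch of one schedule-free SGD step with decoupled weight decay.
# This is an illustration of the idea, NOT the actual pytorch-optimizer code.
def schedulefree_sgd_step(x, z, grad_fn, lr, beta, decay, step):
    # evaluate the gradient at the interpolated point y
    y = [(1 - beta) * zi + beta * xi for zi, xi in zip(z, x)]
    grad = grad_fn(y)
    # decoupled weight decay taken at y, so it affects z directly
    # and x indirectly through the averaging below
    grad = [g + decay * yi for g, yi in zip(grad, y)]
    z = [zi - lr * g for zi, g in zip(z, grad)]
    # x is a running average of the z iterates
    ckp1 = 1.0 / (step + 1)
    x = [(1 - ckp1) * xi + ckp1 * zi for xi, zi in zip(x, z)]
    return x, z
```

Applying decay only to z (or only to y) would leave the two sequences decaying at different rates, which is why plain AdamW-style decay does not carry over directly.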
lr includes beta2 bias correction
The original AdamW/RAdam schedule-free implementations do not include beta2 bias correction in the weights used to compute the mean of parameters, so I removed bias_correction2 from lr.
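Why this matters can be sketched as follows. As I understand the schedule-free method, the running average of z uses weights proportional to lr_k**2, so folding sqrt(bias_correction2) into lr silently re-weights the parameter average, most visibly at early steps where bias_correction2 is far from 1 (this is my own illustration with made-up names, not the library's code):

```python
import math

# Illustration only: compare the schedule-free averaging weights when
# sqrt(bias_correction2) is (or is not) folded into the learning rate.
def averaging_weights(base_lr, beta2, num_steps, include_bc2):
    lrs = []
    for k in range(1, num_steps + 1):
        bc2 = 1.0 - beta2 ** k
        lrs.append(base_lr * math.sqrt(bc2) if include_bc2 else base_lr)
    # weights for the running average of z, proportional to lr_k ** 2
    weights = [lr * lr for lr in lrs]
    total = sum(weights)
    return [w / total for w in weights]
```

With a constant lr the weights are uniform; with bias_correction2 folded in, early steps are down-weighted, which changes which z iterates dominate the average.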
use_palm
The original implementations do not have use_palm.
I don't know what behavior use_palm is meant to implement, because it isn't documented.
My guess is that it schedules beta2 as (1 - step**-0.8), following this paper:
https://arxiv.org/abs/2204.02311
Since (1 - step**-0.8) is 0 when step is 1, I thought it could be implemented by simply removing bias correction and setting beta2 to (1 - step**-0.8).
However, as in the PaLM implementation below, it is possible to change the betas at any step without problems.
If that implementation were run while changing beta2, I think its behavior would differ slightly from this PR.
https://github.com/conceptofmind/PaLM/blob/c95d8c42fdb957eb8ec506954f74dee9aa184b5c/palm/stable_adamw.py#L71
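The reading above can be sketched in a few lines. This is my inference from the PaLM paper, not a documented definition of use_palm; the function names are mine:

```python
# Assumed use_palm behavior: PaLM-style beta2 schedule (1 - step**-0.8).
def palm_beta2(step):
    return 1.0 - step ** -0.8

def update_second_moment(v, grad_sq, step):
    b2 = palm_beta2(step)
    # b2 == 0 at step 1, so v == grad_sq exactly after the first update,
    # which is precisely what bias correction would otherwise accomplish
    return b2 * v + (1.0 - b2) * grad_sq
```

Because the first update fully overwrites v, the usual 1/(1 - beta2**t) de-biasing becomes redundant under this schedule, which is the equivalence claimed above.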
RAdam and use_palm
When use_palm is True, the scheduled beta2 must also be used to compute the RAdam rectification coefficients.
However, I don't know how to derive a correct RAdam formula when beta2 is scheduled.
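For reference, the standard RAdam SMA-length computation is below. Its derivation assumes a constant beta2, which is exactly the assumption that use_palm's schedule breaks (sketch only; names are mine):

```python
# Standard RAdam simple-moving-average length, valid for CONSTANT beta2.
# Under a scheduled beta2 (use_palm) this closed form no longer applies
# and would need to be re-derived.
def radam_n_sma(beta2, step):
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    return n_sma_max - 2.0 * step * beta2 ** step / (1.0 - beta2 ** step)
```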
ScheduleFreeRAdam when degenerated_to_sgd=False
I changed ScheduleFreeRAdam to set lr to 0 when n_sma < 4 and degenerated_to_sgd is False.
I also added tests with degenerated_to_sgd=False.
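The gating described above amounts to the following (hypothetical sketch; function and argument names are mine, not the repository's):

```python
# Sketch of the RAdam step-size gating with degenerated_to_sgd.
def radam_step_size(lr, n_sma, rect_term, degenerated_to_sgd):
    if n_sma >= 4.0:
        return lr * rect_term  # variance is tractable: rectified Adam-like step
    # variance not yet tractable: either fall back to a plain SGD step,
    # or (the change in this PR) take no step at all
    return lr if degenerated_to_sgd else 0.0
```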
Other changes (bug fixes, small refactors)
I applied bias_correction2 before ams_bound.
I believe this is the correct AMSGrad behavior, but elsewhere in this repository bias_correction2 is applied after ams_bound.
Is that the intended behavior?
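The two orderings in question can be sketched side by side (my illustration, not the repository's code). Note that the running max tracks different quantities in the two variants, so they diverge once bias_correction2 changes between steps:

```python
import math

# De-bias v first, then take the running max (the ordering this PR uses).
def denom_debias_then_max(v, v_max, beta2, step):
    v_hat = v / (1.0 - beta2 ** step)
    v_max = max(v_max, v_hat)  # v_max tracks de-biased estimates
    return math.sqrt(v_max), v_max

# Running max first, then de-bias (the ordering used elsewhere in the repo).
def denom_max_then_debias(v, v_max, beta2, step):
    v_max = max(v_max, v)      # v_max tracks raw (biased) estimates
    return math.sqrt(v_max / (1.0 - beta2 ** step)), v_max
```

At step 1 both give the same denominator, but from step 2 onward the first variant can keep a larger de-biased maximum, yielding a smaller effective step.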
Also, since weight_decouple has no meaning for SGD, I removed weight_decouple from ScheduleFreeSGD.