
[Fix] schedulefree #351

Closed · wants to merge 4 commits
Conversation

hatonosuke

Problem (Why?)

The schedule-free optimizers in pytorch-optimizer differ from the original implementation linked below.

https://github.com/facebookresearch/schedule_free

Below are the differences and my questions.
Could you please review them?

weight decay

Since weight decay needs to be applied appropriately to both z and y, the handling differs from standard Adam.
I implemented schedule-free-specific weight-decay handling.
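As a rough illustration (pure-Python scalars with made-up names, not the actual pytorch-optimizer or schedule_free code; `ckp1` stands for the schedule-free averaging weight), coupled weight decay has to be evaluated at the interpolated point y rather than at x or z alone, because y is where the gradient itself is taken:

```python
def schedulefree_sgd_step(z, x, y, grad, lr, weight_decay, beta, ckp1):
    """One hedged sketch of a schedule-free SGD step on a scalar parameter."""
    # assumption: coupled weight decay, evaluated at the gradient point y
    grad = grad + weight_decay * y
    z = z - lr * grad                 # fast iterate
    x = (1.0 - ckp1) * x + ckp1 * z   # running average of the z iterates
    y = (1.0 - beta) * z + beta * x   # interpolation used for the next gradient
    return z, x, y
```

With weight_decay = 0 and grad = 0 the step is a no-op, which is a quick sanity check on the interpolation formulas.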

lr includes beta2 bias correction

The original implementations of schedule-free AdamW/RAdam do not include the beta2 bias correction in the weights used to calculate the mean of the parameters, so I removed bias_correction2 from the lr.
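To illustrate the difference (a hypothetical helper, not code from either repository): folding sqrt(bias_correction2) into the step size shrinks the effective lr during early steps, and since the schedule-free parameter-averaging weights are derived from the lr, the two variants end up averaging the parameters differently:

```python
import math

def effective_lr(base_lr: float, step: int, beta2: float,
                 include_bias_correction2: bool) -> float:
    """Step size with/without the beta2 bias correction folded in."""
    bias_correction2 = 1.0 - beta2 ** step
    if include_bias_correction2:
        # early in training this factor is far below 1, so the step shrinks
        return base_lr * math.sqrt(bias_correction2)
    return base_lr
```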

use_palm

The original implementation does not have use_palm, and I don't know what behavior use_palm is trying to implement because it is undocumented.
Therefore, I assumed it is meant to schedule beta2 as (1 - step**-0.8), following the paper below.

https://arxiv.org/abs/2204.02311

Since (1 - step**-0.8) is 0 when step is 1, I thought it could be implemented simply by removing the bias correction and setting beta2 to (1 - step**-0.8).

However, as in the PaLM implementation below, it is also possible to keep changing the betas every step without any problems.
If that implementation is run while changing beta2, I think its behavior will differ slightly from this PR.

https://github.com/conceptofmind/PaLM/blob/c95d8c42fdb957eb8ec506954f74dee9aa184b5c/palm/stable_adamw.py#L71
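Under that reading, use_palm would reduce to a one-line beta2 schedule (a sketch of my assumption about its intent, not necessarily what use_palm was actually meant to do):

```python
def palm_beta2(step: int, power: float = 0.8) -> float:
    """PaLM-style schedule: beta2 = 1 - step**(-0.8)."""
    # equals 0 at step 1, so the usual (1 - beta2**step) bias
    # correction becomes unnecessary when this schedule is used
    return 1.0 - float(step) ** (-power)
```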

RAdam and use_palm

When use_palm is True, the scheduled beta2 should also be used to calculate the RAdam coefficients.
However, I was unable to derive a reasonable RAdam formula for a scheduled beta2.

ScheduleFreeRAdam when degenerated_to_sgd=False

I changed ScheduleFreeRAdam to set lr to 0 when n_sma < 4 and degenerated_to_sgd is False, and added tests with degenerated_to_sgd=False.
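The changed branch can be sketched like this (a scalar pseudo-implementation using the standard RAdam rectification term; variable names are mine, not the repository's):

```python
import math

def radam_step_size(lr, step, beta2, degenerated_to_sgd):
    """Hedged sketch of the RAdam step-size branch discussed above."""
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if n_sma >= 4.0:
        # variance is tractable: apply the RAdam rectification term
        rect = math.sqrt(
            ((n_sma - 4.0) * (n_sma - 2.0) * n_sma_max)
            / ((n_sma_max - 4.0) * (n_sma_max - 2.0) * n_sma)
        )
        return lr * rect
    # early steps: fall back to plain SGD, or take no step at all
    return lr if degenerated_to_sgd else 0.0
```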

Other changes (bug fixes, small refactors)

I applied bias_correction2 before ams_bound.
I believe this is the correct behavior for AMSGrad, but elsewhere in this repository bias_correction2 is applied after ams_bound.

Is this the intended behavior?
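The two orderings can be compared with a scalar sketch (a hypothetical helper; v is the raw second-moment estimate and v_max the AMSGrad running max). Because bias_correction2 changes every step, taking the max before or after dividing by it can produce different denominators:

```python
import math

def denom(v, v_max, step, beta2, eps, correct_before_amsbound):
    """Denominator of the update under the two possible orderings."""
    bias_correction2 = 1.0 - beta2 ** step
    if correct_before_amsbound:
        # this PR: bias-correct first, then take the AMSGrad max
        v_max = max(v_max, v / bias_correction2)
        return math.sqrt(v_max) + eps
    # elsewhere in the repository: max first, then bias-correct
    v_max = max(v_max, v)
    return math.sqrt(v_max / bias_correction2) + eps
```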

And since weight_decouple has no meaning for SGD, I removed weight_decouple from ScheduleFreeSGD.

Owner

@kozistr kozistr left a comment


first of all, thanks for your contributions! i commented on your reviews, and please feel free to leave comments or ask questions :)

I haven't answered everything yet, but I'll review them maybe later today.


kozistr commented Feb 19, 2025

I applied bias_correction2 before ams_bound. I believe this behavior is correct in AMSGrad, but elsewhere in this repository bias_correction2 is applied after ams_bound.

yes. as you said, bias_correction2 is applied before ams_bound. however, in my implementation, i thought dividing de_nom by the sqrt of bias_correction2 brought slightly more efficient computation, and it gives almost the same result i guess.

however, on second thought, there seems no advantage in doing that, so gonna refactor that later.

thanks for bringing this up :)


kozistr commented Feb 20, 2025

@hatonosuke I checked all of your work and reviews, and seems like there's kinda a lot to change and refactor. So, I'm thinking of closing this PR and opening another branch, then working there based on this PR. (of course, the credit all goes to you)

Here are things to work on:

  1. fix weight decay issue (maybe I'm gonna remove the decouple wd option)
  2. update the sf optimizers with the latest one
  3. fix when step size is negative in sfradam
  4. refactor palm part
  5. refactor amsbound part

thanks for your contributions! and if you have more questions or reviews, please feel free to leave here :)

@hatonosuke
Author

I got it.
Thank you for your responses!

@kozistr kozistr closed this Feb 22, 2025
@kozistr kozistr mentioned this pull request Feb 22, 2025