Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] schedulefree #351

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 35 additions & 48 deletions pytorch_optimizer/optimizer/schedulefree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class ScheduleFreeSGD(BaseOptimizer):
:param lr: float. learning rate.
:param momentum: float. momentum factor, must be between 0 and 1 exclusive.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
:param r: float. use polynomial weighting in the average with power r.
:param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
Expand All @@ -30,7 +29,6 @@ def __init__(
lr: float = 1.0,
momentum: float = 0.9,
weight_decay: float = 0.0,
weight_decouple: bool = True,
fixed_decay: bool = False,
r: float = 0.0,
weight_lr_power: float = 2.0,
Expand All @@ -47,7 +45,6 @@ def __init__(
'lr': lr,
'momentum': momentum,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'fixed_decay': fixed_decay,
'r': r,
'weight_lr_power': weight_lr_power,
Expand Down Expand Up @@ -132,21 +129,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:
if len(state) == 0:
state['z'] = p.clone()

self.apply_weight_decay(
p=p,
grad=grad,
lr=lr,
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=group['fixed_decay'],
)

z = state['z']

grad.mul_(lr)
grad.add_(p, alpha=group['weight_decay'] * (1.0 if group['fixed_decay'] else lr))

p.lerp_(z, weight=checkpoint)
p.add_(grad, alpha=lr * (momentum * (1.0 - checkpoint) - 1))
p.add_(grad, alpha=momentum * (1.0 - checkpoint) - 1)

z.sub_(grad, alpha=lr)
z.sub_(grad)

return loss

Expand Down Expand Up @@ -259,9 +250,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:

beta1, beta2 = group['betas']

bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
bias_correction2: float = 1.0 - beta2 ** group['step']

lr: float = group['lr'] * schedule * bias_correction2_sq
lr: float = group['lr'] * schedule
lr_max = group['lr_max'] = max(lr, group['lr_max'])

weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
Expand All @@ -271,9 +262,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

if group['use_palm']:
beta2: float = 1.0 - group['step'] ** -0.8
debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step'])
else:
debias: float = beta2
# unnecessary bias correction when PaLM beta2 scheduling
bias_correction2 = 1.0

for p in group['params']:
if p.grad is None:
Expand All @@ -289,31 +279,27 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['z'] = p.clone()
state['exp_avg_sq'] = torch.zeros_like(p)

self.apply_weight_decay(
p=p,
grad=grad,
lr=lr,
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=group['fixed_decay'],
)
if not group['weight_decouple']:
grad.add_(p, alpha=group['weight_decay'])

z, exp_avg_sq = state['z'], state['exp_avg_sq']
exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)

exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
de_nom = self.apply_ams_bound(
ams_bound=group['ams_bound'],
exp_avg_sq=exp_avg_sq,
exp_avg_sq=exp_avg_sq.div(bias_correction2),
max_exp_avg_sq=state.get('max_exp_avg_sq', None),
eps=group['eps'],
)

grad.div_(de_nom)
grad.mul_(lr)
if group['weight_decouple']:
grad.add_(p, alpha=group['weight_decay'] * (1.0 if group['fixed_decay'] else lr))

p.lerp_(z, weight=checkpoint)
p.add_(grad, alpha=lr * (beta1 * (1.0 - checkpoint) - 1))
p.add_(grad, alpha=beta1 * (1.0 - checkpoint) - 1)

z.sub_(grad, alpha=lr)
z.sub_(grad)

return loss

Expand Down Expand Up @@ -428,6 +414,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
n_sma_threshold=4,
degenerated_to_sgd=group['degenerated_to_sgd'],
)
if n_sma > 4:
# cancel bias correction2
lr = lr / bias_correction2_sq
elif lr < 0.:
# n_sma < 4.0 and degenerated_to_sgd is False
lr = 0.0

lr_max = group['lr_max'] = max(lr, group['lr_max'])

Expand All @@ -436,13 +428,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:

checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

adaptive_y_lr: float = lr * (beta1 * (1.0 - checkpoint) - 1.0)

if group['use_palm']:
beta2: float = 1.0 - group['step'] ** -0.8
debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step'])
else:
debias: float = beta2
# unnecessary bias correction when PaLM beta2 scheduling
bias_correction2_sq = 1.0

for p in group['params']:
if p.grad is None:
Expand All @@ -458,25 +447,23 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['z'] = p.clone()
state['exp_avg_sq'] = torch.zeros_like(p)

if not group['weight_decouple']:
grad.add_(p, alpha=group['weight_decay'])

z, exp_avg_sq = state['z'], state['exp_avg_sq']
exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

if n_sma > 4.0:
de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
grad.div_(de_nom)

self.apply_weight_decay(
p=p,
grad=grad,
lr=lr,
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=group['fixed_decay'],
)
grad.mul_(lr)
if group['weight_decouple']:
grad.add_(p, alpha=group['weight_decay'] * (1.0 if group['fixed_decay'] else lr))

p.lerp_(z, weight=checkpoint)
p.add_(grad, alpha=adaptive_y_lr)
p.add_(grad, alpha=beta1 * (1.0 - checkpoint) - 1.0)

z.sub_(grad, alpha=lr)
z.sub_(grad)

return loss
4 changes: 3 additions & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,9 @@
(Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeAdamW, {'lr': 1e-2, 'weight_decay': 1e-3, 'use_palm': True}, 5),
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3, 'use_palm': True}, 5),
(ScheduleFreeRAdam, {'lr': 5e0, 'weight_decay': 1e-3}, 10),
(ScheduleFreeRAdam, {'lr': 5e0, 'weight_decay': 1e-3, 'use_palm': True}, 10),
(ScheduleFreeRAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 5),
(ScheduleFreeRAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'use_palm': True, 'degenerated_to_sgd': True}, 5),
(FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
Expand Down