Skip to content
This repository was archived by the owner on Apr 17, 2023. It is now read-only.

Commit 0ab5334

Browse files
ReduceLROnDelayScheduler
1 parent b84ca21 commit 0ab5334

File tree

1 file changed

+46
-3
lines changed

1 file changed

+46
-3
lines changed

torchreid/optim/lr_scheduler.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from torch import optim
1212
from torch.optim.lr_scheduler import _LRScheduler
1313

14-
AVAI_SCH = {'single_step', 'multi_step', 'cosine', 'warmup', 'cosine_cycle', 'reduce_on_plateau', 'onecycle'}
14+
AVAI_SCH = {'single_step', 'multi_step', 'cosine', 'warmup', 'cosine_cycle',
15+
'reduce_on_plateau_delayed', 'reduce_on_plateau', 'onecycle'}
1516

1617
def build_lr_scheduler(optimizer, lr_scheduler, base_scheduler, **kwargs):
1718
if lr_scheduler == 'warmup':
@@ -37,7 +38,8 @@ def _build_scheduler(optimizer,
3738
max_lr=0.1,
3839
patience=5,
3940
lr_decay_factor=100,
40-
pct_start=0.3):
41+
pct_start=0.3,
42+
epoch_delay=0):
4143

4244
init_learning_rate = [param_group['lr'] for param_group in optimizer.param_groups]
4345
if lr_scheduler not in AVAI_SCH:
@@ -96,7 +98,21 @@ def _build_scheduler(optimizer,
9698
lb_lr = [lr / lr_decay_factor for lr in init_learning_rate]
9799
epoch_treshold = max(int(max_epoch * 0.75) - warmup, 1) # 75% of the training - warmup epochs
98100
scheduler = ReduceLROnPlateauV2(optimizer, epoch_treshold, factor=gamma, patience=patience,
99-
threshold=2e-4, verbose=True, min_lr=lb_lr )
101+
threshold=2e-4, verbose=True, min_lr=min_lr)
102+
elif lr_scheduler == 'reduce_on_plateau_delayed':
103+
if epoch_delay < 0:
104+
raise ValueError(f'epoch_delay = {epoch_delay} should be greater than zero')
105+
106+
if max_epoch < epoch_delay:
107+
raise ValueError(f'max_epoch param = {max_epoch} should be greater than'
108+
f' epoch_delay param = {epoch_delay}')
109+
110+
if epoch_delay < warmup:
111+
raise ValueError(f'warmap param = {warmup} should be less than'
112+
f' epoch_delay param = {epoch_delay}')
113+
epoch_treshold = max(int(max_epoch * 0.75) - epoch_delay, 1) # 75% of the training - skipped epochs
114+
scheduler = ReduceLROnPlateauV2Delayed(optimizer, epoch_treshold, epoch_delay, factor=gamma,
115+
patience=patience, threshold=2e-4, verbose=True, min_lr=min_lr)
100116
else:
101117
raise ValueError('Unknown scheduler: {}'.format(lr_scheduler))
102118

@@ -275,3 +291,30 @@ class OneCycleLR(optim.lr_scheduler.OneCycleLR):
275291
@property
276292
def warmup_finished(self):
277293
return self.last_epoch >= self._schedule_phases[0]['end_step']
294+
295+
296+
class ReduceLROnPlateauV2Delayed(ReduceLROnPlateauV2):
297+
"""
298+
ReduceOnPlateuV2 scheduler which starts working only
299+
after certain amount of epochs specified by epoch delay param.
300+
Useful when compression algorithms is applying to prevent
301+
lr drop before full model compression. Warmup included into epoch_delay.
302+
"""
303+
def __init__(self,
304+
optimizer: optim.Optimizer,
305+
epoch_treshold: int,
306+
epoch_delay: int,
307+
**kwargs) -> None:
308+
309+
super().__init__(optimizer, epoch_treshold, **kwargs)
310+
self._epoch_delay = epoch_delay
311+
312+
def step(self, metrics, epoch=None):
313+
# If there was less than self._epoch_delay epochs
314+
# just update epochs counter
315+
if self.last_epoch <= self._epoch_delay:
316+
if epoch is None:
317+
epoch = self.last_epoch + 1
318+
self.last_epoch = epoch
319+
else:
320+
super().step(metrics, epoch=epoch)

0 commit comments

Comments
 (0)