from torch import optim
from torch.optim.lr_scheduler import _LRScheduler

-AVAI_SCH = {'single_step', 'multi_step', 'cosine', 'warmup', 'cosine_cycle', 'reduce_on_plateau', 'onecycle'}
+AVAI_SCH = {'single_step', 'multi_step', 'cosine', 'warmup', 'cosine_cycle',
+            'reduce_on_plateau_delayed', 'reduce_on_plateau', 'onecycle'}

def build_lr_scheduler(optimizer, lr_scheduler, base_scheduler, **kwargs):
    if lr_scheduler == 'warmup':
@@ -37,7 +38,8 @@ def _build_scheduler(optimizer,
                     max_lr=0.1,
                     patience=5,
                     lr_decay_factor=100,
-                     pct_start=0.3):
+                     pct_start=0.3,
+                     epoch_delay=0):

    init_learning_rate = [param_group['lr'] for param_group in optimizer.param_groups]
    if lr_scheduler not in AVAI_SCH:
@@ -96,7 +98,21 @@ def _build_scheduler(optimizer,
        lb_lr = [lr / lr_decay_factor for lr in init_learning_rate]
        epoch_treshold = max(int(max_epoch * 0.75) - warmup, 1)  # 75% of the training - warmup epochs
        scheduler = ReduceLROnPlateauV2(optimizer, epoch_treshold, factor=gamma, patience=patience,
-                                        threshold=2e-4, verbose=True, min_lr=lb_lr)
+                                        threshold=2e-4, verbose=True, min_lr=min_lr)
+    elif lr_scheduler == 'reduce_on_plateau_delayed':
+        if epoch_delay < 0:
+            raise ValueError(f'epoch_delay = {epoch_delay} should be greater than or equal to zero')
+
+        if max_epoch < epoch_delay:
+            raise ValueError(f'max_epoch param = {max_epoch} should be greater than or equal to'
+                             f' epoch_delay param = {epoch_delay}')
+
+        if epoch_delay < warmup:
+            raise ValueError(f'warmup param = {warmup} should be less than or equal to'
+                             f' epoch_delay param = {epoch_delay}')
+        epoch_treshold = max(int(max_epoch * 0.75) - epoch_delay, 1)  # 75% of the training - skipped epochs
+        scheduler = ReduceLROnPlateauV2Delayed(optimizer, epoch_treshold, epoch_delay, factor=gamma,
+                                               patience=patience, threshold=2e-4, verbose=True, min_lr=min_lr)
    else:
        raise ValueError('Unknown scheduler: {}'.format(lr_scheduler))
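For context, a minimal usage sketch of the new branch; the argument names (gamma, max_epoch, warmup, patience, epoch_delay) mirror what this hunk uses, but the concrete values and the direct call to _build_scheduler are illustrative assumptions rather than part of the change:

    # Hypothetical call; values are placeholders.
    scheduler = _build_scheduler(optimizer,
                                 lr_scheduler='reduce_on_plateau_delayed',
                                 gamma=0.1,        # multiplicative LR drop on plateau
                                 max_epoch=100,    # feeds the 75%-of-training epoch_treshold heuristic
                                 warmup=5,         # must not exceed epoch_delay
                                 patience=5,
                                 epoch_delay=10)   # plateau tracking starts only after this epoch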
@@ -275,3 +291,30 @@ class OneCycleLR(optim.lr_scheduler.OneCycleLR):
    @property
    def warmup_finished(self):
        return self.last_epoch >= self._schedule_phases[0]['end_step']
+
+
+class ReduceLROnPlateauV2Delayed(ReduceLROnPlateauV2):
+    """
+    ReduceLROnPlateauV2 scheduler that starts working only after a certain
+    number of epochs, specified by the epoch_delay param.
+    Useful when a compression algorithm is applied, to prevent the LR from
+    dropping before the model is fully compressed. Warmup is included in epoch_delay.
+    """
+    def __init__(self,
+                 optimizer: optim.Optimizer,
+                 epoch_treshold: int,
+                 epoch_delay: int,
+                 **kwargs) -> None:
+
+        super().__init__(optimizer, epoch_treshold, **kwargs)
+        self._epoch_delay = epoch_delay
+
+    def step(self, metrics, epoch=None):
+        # If fewer than self._epoch_delay epochs have passed,
+        # just advance the epoch counter without touching the LR
+        if self.last_epoch <= self._epoch_delay:
+            if epoch is None:
+                epoch = self.last_epoch + 1
+            self.last_epoch = epoch
+        else:
+            super().step(metrics, epoch=epoch)
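A short training-loop sketch of the intended behaviour, assuming ReduceLROnPlateauV2 keeps the step(metrics) interface of torch.optim.lr_scheduler.ReduceLROnPlateau; the loop, model, train_one_epoch() and validate() helpers are hypothetical:

    # Epochs 0..epoch_delay only advance last_epoch; afterwards normal plateau tracking applies.
    for epoch in range(max_epoch):
        train_one_epoch(model)          # hypothetical training step
        val_loss = validate(model)      # hypothetical validation metric
        scheduler.step(val_loss)        # no LR drop can happen while last_epoch <= epoch_delay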