@@ -941,3 +941,42 @@ def get_model_sparsity(self):
         if self._compression_manager is not None:
             sparsity = self._compression_manager.model.report_sparsity()[-1]
         return sparsity
+
+    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+        # TODO: can be removed once transformers >= v4.38.0
+        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+            if is_torch_tpu_available():
+                xm.mark_step()
+
+            logs: Dict[str, float] = {}
+
+            # all_gather + mean() to get average loss over all processes
+            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+            # reset tr_loss to zero
+            tr_loss -= tr_loss
+
+            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            logs["learning_rate"] = self._get_learning_rate()
+
+            self._total_loss_scalar += tr_loss_scalar
+            self._globalstep_last_logged = self.state.global_step
+            self.store_flos()
+
+            self.log(logs)
+
+        metrics = None
+        if self.control.should_evaluate:
+            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+            self._report_to_hp_search(trial, self.state.global_step, metrics)
+
+            # Run delayed LR scheduler now that metrics are populated
+            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                metric_to_check = self.args.metric_for_best_model
+                if not metric_to_check.startswith("eval_"):
+                    metric_to_check = f"eval_{metric_to_check}"
+                self.lr_scheduler.step(metrics[metric_to_check])
+
+        if self.control.should_save:
+            self._save_checkpoint(model, trial, metrics=metrics)
+            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
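The TODO above marks this override as a temporary backport. As a minimal sketch (the guard below is illustrative and not part of this commit), verifying when it becomes removable can be made mechanical by checking the installed transformers version:

```python
from packaging import version

import transformers

# Illustrative check, not part of this commit: once the installed
# transformers release is >= v4.38.0, the upstream Trainer already provides
# the behavior of the backported _maybe_log_save_evaluate above, so the
# override can be deleted.
if version.parse(transformers.__version__) >= version.parse("4.38.0"):
    print("Backported _maybe_log_save_evaluate is no longer needed.")
```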