Skip to content

Commit 09f4620

Browse files
committed
Fix compatibility with the latest transformers release
1 parent c7e228f commit 09f4620

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

optimum/intel/neural_compressor/trainer.py

+39
Original file line numberDiff line numberDiff line change
@@ -941,3 +941,42 @@ def get_model_sparsity(self):
941941
if self._compression_manager is not None:
942942
sparsity = self._compression_manager.model.report_sparsity()[-1]
943943
return sparsity
944+
945+
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
    """Log training metrics, optionally evaluate, and optionally checkpoint.

    Backport of the upstream ``Trainer`` hook for compatibility with older
    ``transformers`` versions.
    # TODO : can be removed once transformers >= v4.38.0

    Args:
        tr_loss: Running training-loss tensor; reset to zero in place after logging.
        model: Model passed through to ``_save_checkpoint``.
        trial: Hyperparameter-search trial, reported via ``_report_to_hp_search``.
        epoch: Current epoch (unused here; kept for signature compatibility).
        ignore_keys_for_eval: Keys forwarded to ``self.evaluate``.
    """
    should_log_now = (
        self.control.should_log and self.state.global_step > self._globalstep_last_logged
    )
    if should_log_now:
        # On TPU, force pending lazy ops to execute before gathering the loss.
        if is_torch_tpu_available():
            xm.mark_step()

        # Average the loss across all processes.
        loss_value = self._nested_gather(tr_loss).mean().item()

        # Zero the running loss in place so accumulation restarts cleanly.
        tr_loss -= tr_loss

        steps_since_last_log = self.state.global_step - self._globalstep_last_logged
        log_payload = {
            "loss": round(loss_value / steps_since_last_log, 4),
            "learning_rate": self._get_learning_rate(),
        }

        self._total_loss_scalar += loss_value
        self._globalstep_last_logged = self.state.global_step
        self.store_flos()

        self.log(log_payload)

    eval_metrics = None
    if self.control.should_evaluate:
        eval_metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
        self._report_to_hp_search(trial, self.state.global_step, eval_metrics)

        # ReduceLROnPlateau needs the freshly computed metric to decide on a step.
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            monitored = self.args.metric_for_best_model
            if not monitored.startswith("eval_"):
                monitored = f"eval_{monitored}"
            self.lr_scheduler.step(eval_metrics[monitored])

    if self.control.should_save:
        self._save_checkpoint(model, trial, metrics=eval_metrics)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
INSTALL_REQUIRE = [
1515
"torch>=1.11",
1616
"optimum>=1.17.0",
17-
"transformers>=4.26.0",
17+
"transformers>=4.29.0,<4.39.0",
1818
"datasets>=1.4.0",
1919
"sentencepiece",
2020
"scipy",

0 commit comments

Comments
 (0)