Skip to content

Commit 51d9c08

Browse files
Support transformers v4.49 logits_to_keep for IPEX (huggingface#1188)
* fix logits_to_keep Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix typo Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * Update optimum/intel/ipex/modeling_base.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * fix Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add comments Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 93ee486 commit 51d9c08

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

optimum/intel/ipex/modeling_base.py

+18
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,17 @@ def _reorder_cache(self, *args, **kwargs):
353353
def prepare_inputs_for_generation(self, *args, **kwargs):
354354
return self.model.prepare_inputs_for_generation(*args, **kwargs)
355355

356+
def _supports_logits_to_keep(self) -> bool:
357+
"""
358+
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359+
to save memory. Checking it in this way allows to avoid using a new model attribute.
360+
"""
361+
return "logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
362+
356363
def _supports_num_logits_to_keep(self) -> bool:
357364
"""
365+
Will be deprecated after we no longer support transformers < 4.49
366+
358367
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
359368
to save memory. Checking it in this way allows to avoid using a new model attribute.
360369
"""
@@ -470,8 +479,17 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
470479
def get_encoder(self, *args, **kwargs):
471480
return self.model.get_encoder(*args, **kwargs)
472481

482+
def _supports_logits_to_keep(self) -> bool:
483+
"""
484+
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
485+
to save memory. Checking it in this way allows to avoid using a new model attribute.
486+
"""
487+
return "logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
488+
473489
def _supports_num_logits_to_keep(self) -> bool:
474490
"""
491+
Will be deprecated after we no longer support transformers < 4.49
492+
475493
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
476494
to save memory. Checking it in this way allows to avoid using a new model attribute.
477495
"""

0 commit comments

Comments
 (0)