Skip to content

Commit eb31cd2

Browse files
committed
fix
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 2d65a72 commit eb31cd2

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

optimum/intel/ipex/modeling_base.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,14 @@ def _supports_logits_to_keep(self) -> bool:
358358
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359359
to save memory. Checking it in this way allows to avoid using a new model attribute.
360360
"""
361-
logits_to_keep_name = "logits_to_keep" if is_transformers_version(">", "4.49") else "num_logits_to_keep"
362-
return logits_to_keep_name in set(inspect.signature(self.model.forward).parameters.keys())
361+
return "logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
362+
363+
def _supports_num_logits_to_keep(self) -> bool:
364+
"""
365+
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
366+
to save memory. Checking it in this way allows to avoid using a new model attribute.
367+
"""
368+
return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
363369

364370
def generate(self, *args, **kwargs):
365371
if self._add_patch and kwargs.get("assistant_model", None):
@@ -478,6 +484,13 @@ def _supports_logits_to_keep(self) -> bool:
478484
"""
479485
return "logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
480486

487+
def _supports_num_logits_to_keep(self) -> bool:
488+
"""
489+
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
490+
to save memory. Checking it in this way allows to avoid using a new model attribute.
491+
"""
492+
return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
493+
481494
def _init_warmup(self):
482495
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
483496
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)

0 commit comments

Comments
 (0)