Skip to content

Commit d9af9fe

Browse files
committed
fix logits_to_keep
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 6cceb30 commit d9af9fe

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

optimum/intel/ipex/modeling_base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,12 @@ def _reorder_cache(self, *args, **kwargs):
353353
def prepare_inputs_for_generation(self, *args, **kwargs):
354354
return self.model.prepare_inputs_for_generation(*args, **kwargs)
355355

356-
def _supports_num_logits_to_keep(self) -> bool:
356+
def _supports_logits_to_keep(self) -> bool:
357357
"""
358-
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
358+
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
359359
to save memory. Checking it in this way allows to avoid using a new model attribute.
360360
"""
361-
return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
361+
return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
362362

363363
def generate(self, *args, **kwargs):
364364
if self._add_patch and kwargs.get("assistant_model", None):
@@ -470,12 +470,12 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
470470
def get_encoder(self, *args, **kwargs):
471471
return self.model.get_encoder(*args, **kwargs)
472472

473-
def _supports_num_logits_to_keep(self) -> bool:
473+
def _supports_logits_to_keep(self) -> bool:
474474
"""
475-
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
475+
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
476476
to save memory. Checking it in this way allows to avoid using a new model attribute.
477477
"""
478-
return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
478+
return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
479479

480480
def _init_warmup(self):
481481
inputs = prepare_jit_inputs(self.model, self.export_feature, False)

0 commit comments

Comments
 (0)