Skip to content

Commit 12ce691

Browse files
committed
enable phi
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 63bee4e commit 12ce691

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

optimum/intel/ipex/modeling_base.py

+14
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,13 @@ def _reorder_cache(self, *args, **kwargs):
317317
def prepare_inputs_for_generation(self, *args, **kwargs):
318318
return self.model.prepare_inputs_for_generation(*args, **kwargs)
319319

320+
def _supports_num_logits_to_keep(self) -> bool:
321+
"""
322+
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
323+
to save memory. Checking it in this way allows to avoid using a new model attribute.
324+
"""
325+
return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
326+
320327
def generate(self, *args, **kwargs):
321328
if self._add_patch and kwargs.get("assistant_model", None):
322329
raise ValueError(
@@ -427,6 +434,13 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
427434
def get_encoder(self, *args, **kwargs):
428435
return self.model.get_encoder(*args, **kwargs)
429436

437+
def _supports_num_logits_to_keep(self) -> bool:
438+
"""
439+
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
440+
to save memory. Checking it in this way allows to avoid using a new model attribute.
441+
"""
442+
return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
443+
430444
def _init_warmup(self):
431445
inputs = prepare_jit_inputs(self.model, self.export_feature, False)
432446
self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)

tests/ipex/test_modeling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
239239
"mistral",
240240
"llama",
241241
"llama2",
242-
# "phi",
242+
"phi",
243243
"distilgpt2",
244244
"mpt",
245245
"opt",

tests/ipex/test_pipelines.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class PipelinesIntegrationTest(unittest.TestCase):
6666
"mistral",
6767
"mpt",
6868
"opt",
69+
"phi",
6970
"qwen2",
7071
)
7172
QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = (

0 commit comments

Comments
 (0)