
Commit 41876fb

eaidova and echarlaix authored
Add phi3 export openvino (#686)
* support phi3 export openvino

* fix inv_freq tracing based on latest changes in model

* add test model

* Update optimum/exporters/openvino/model_patcher.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 33fc7b7 commit 41876fb

File tree

4 files changed: +37 -2 lines changed


optimum/exporters/openvino/model_configs.py

+20 -1

@@ -19,7 +19,7 @@
 from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig
+from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig, PhiOnnxConfig
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (

@@ -37,6 +37,7 @@
     GemmaModelPatcher,
     LlamaModelPatcher,
     MixtralModelPatcher,
+    Phi3ModelPatcher,
     QwenModelPatcher,
 )
 

@@ -440,6 +441,24 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+@register_in_tasks_manager(
+    "phi3",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class Phi3OpenVINOConfig(PhiOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
     def __init__(
         self,
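The registration above is what lets the OpenVINO export pipeline find the new config by its "phi3" model type; the config itself reuses PhiOnnxConfig for inputs and outputs and only overrides patch_model_for_export so that Phi3ModelPatcher (added below in model_patcher.py) is applied while the model is traced. A minimal lookup sketch, assuming optimum-intel is installed and assuming optimum's TasksManager.get_exporter_config_constructor API (keyword names may differ across optimum versions); the model_type, task, and library_name strings mirror the registration above:

import optimum.exporters.openvino.model_configs  # noqa: F401 -- registrations run at import time
from optimum.exporters.tasks import TasksManager

# Sketch only: resolve the constructor registered for the "openvino" exporter.
config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="openvino",
    model_type="phi3",
    task="text-generation-with-past",
    library_name="transformers",
)
print(config_constructor)  # expected to resolve to Phi3OpenVINOConfig

The returned constructor is then called with the model's config to build the export config, whose patch_model_for_export hook supplies the patcher used as a context manager during tracing.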

optimum/exporters/openvino/model_patcher.py

+14

@@ -623,3 +623,17 @@ def __init__(
         # model has first inference buffers initialization
         if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
+
+
+class Phi3ModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
+        # init inv_freq for torchscript tracing
+        for layer in self._model.model.layers:
+            if layer.self_attn.rotary_emb.inv_freq is None:
+                rotary_emb = layer.self_attn.rotary_emb
+                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
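The patcher guards against the rotary embedding leaving inv_freq unset (None) until the first forward pass, which the linked modeling_phi3.py revision can do and which breaks torchscript tracing; it therefore precomputes the standard RoPE inverse frequencies up front. A standalone sketch of that formula, where base and dim stand in for rotary_emb.base and rotary_emb.dim (the concrete values below are illustrative, roughly Phi-3-mini's head dimension and rope base, not read from the commit):

import torch

# Standard RoPE inverse-frequency computation, matching the expression in the patcher.
base, dim = 10000.0, 96  # illustrative values, not taken from the commit
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.shape)  # torch.Size([48]) -- one frequency per pair of rotary dimensions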

tests/openvino/test_modeling.py

+2 -1

@@ -535,13 +535,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "stablelm",
         "starcoder2",
         "phi",
+        "phi3",
         "internlm2",
         "orion",
         "falcon",
         "falcon-40b",
     )
     GENERATION_LENGTH = 100
-    REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion")
+    REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion", "phi3")
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):

tests/openvino/utils_tests.py

+1

@@ -78,6 +78,7 @@
     "pegasus": "hf-internal-testing/tiny-random-pegasus",
     "pix2struct": "fxmarty/pix2struct-tiny-random",
     "phi": "echarlaix/tiny-random-PhiForCausalLM",
+    "phi3": "katuni4ka/tiny-random-phi3",
     "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
     "qwen": "katuni4ka/tiny-random-qwen",
     "qwen2": "Qwen/Qwen1.5-0.5B",

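With the config, the patcher, and the tiny test checkpoint in place, phi3 exports and runs like the other supported decoders. A rough usage sketch mirroring what the integration test exercises (trust_remote_code is passed because "phi3" was added to REMOTE_CODE_MODELS above; the prompt and generation length are illustrative):

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-phi3"  # the test model registered above
# export=True converts the PyTorch checkpoint to OpenVINO on the fly,
# going through Phi3OpenVINOConfig and Phi3ModelPatcher from this commit.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))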