Skip to content

Commit 827bb9e

Browse files
committed
fix inv_freq tracing based on latest changes in model
1 parent 7ff7e12 commit 827bb9e

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

optimum/exporters/openvino/model_configs.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
LlamaModelPatcher,
3838
MixtralModelPatcher,
3939
QwenModelPatcher,
40+
Phi3ModelPatcher,
4041
)
4142

4243

@@ -451,4 +452,7 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
451452
library_name="transformers",
452453
)
453454
class Phi3OpenVINOConfig(PhiOnnxConfig):
454-
pass
455+
def patch_model_for_export(
456+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
457+
) -> "ModelPatcher":
458+
return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)

optimum/exporters/openvino/model_patcher.py

+13
Original file line number | Diff line number | Diff line change
@@ -623,3 +623,16 @@ def __init__(
623623
# model has first inference buffers initialization
624624
if hasattr(self._model.lm_head, "first_flag"):
625625
self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
626+
627+
628+
class Phi3ModelPatcher(DecoderModelPatcher):
629+
def __enter__(self):
630+
super().__enter__()
631+
632+
# init inv_freq for torchscript tracing
633+
for layer in self._model.model.layers:
634+
if layer.self_attn.rotary_emb.inv_freq is None:
635+
rotary_emb = layer.self_attn.rotary_emb
636+
layer.self_attn.rotary_emb.inv_freq = 1.0 / (
637+
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
638+
)

0 commit comments

Comments
 (0)