Commit 8c2b787

Add sdpa for phi3 openvino model (#705)

eaidova and echarlaix authored

* add sdpa for phi3 openvino model
* fix pkv filling according model code
* Update optimum/exporters/openvino/model_patcher.py
  Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
* import helpers from phi3 if available

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent c743886 commit 8c2b787

File tree

2 files changed: +99 -1 lines changed

optimum/exporters/openvino/model_configs.py
optimum/exporters/openvino/model_patcher.py


optimum/exporters/openvino/model_configs.py (+6)

@@ -485,6 +485,12 @@ def patch_model_for_export(
     library_name="transformers",
 )
 class Phi3OpenVINOConfig(PhiOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        MistralDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
+
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
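Phi3OpenVINOConfig plugs MistralDummyPastKeyValuesGenerator in as the past-key-values dummy-input generator and exposes num_key_value_heads through the normalized config, so the exported decoder receives past-key-value inputs shaped for the model's key/value heads. As a rough usage sketch, not part of this diff: exporting and running a phi3 checkpoint through optimum-intel would then look like the snippet below, assuming a transformers release with phi3 support and an optimum-intel build containing this commit; the checkpoint id is only an example.

# Illustrative only: exercise the new phi3 export path end to end.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "microsoft/Phi-3-mini-4k-instruct"  # example checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True runs the OpenVINO exporter, which resolves Phi3OpenVINOConfig
# and applies Phi3ModelPatcher (including the SDPA forward) while tracing.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))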

optimum/exporters/openvino/model_patcher.py (+93 -1)

@@ -951,15 +951,107 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.attention.forward = block.attention._orig_forward


+# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L426
+def _phi3_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    # TODO: remove llama imports when transformers with phi3 support will be released
+    try:
+        from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
+    except ImportError:
+        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
-
         # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
         # init inv_freq for torchscript tracing
         for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
             if layer.self_attn.rotary_emb.inv_freq is None:
                 rotary_emb = layer.self_attn.rotary_emb
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
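For reference, the patched forward above hands the attention math to torch.nn.functional.scaled_dot_product_attention, with is_causal set only when no explicit attention mask is provided and q_len > 1. The snippet below is a standalone sketch, not taken from the commit (toy shapes, random tensors), of the equivalence this relies on: SDPA with is_causal=True matches an explicit softmax(QK^T / sqrt(d) + causal mask) @ V.

# Standalone sketch: fused SDPA vs. the explicit attention it replaces.
import math
import torch

torch.manual_seed(0)
bsz, num_heads, q_len, head_dim = 1, 2, 4, 8  # illustrative sizes
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Fused path used by _phi3_self_attn_sdpa_forward; is_causal builds the
# causal mask internally when no attn_mask is passed.
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference path: softmax(QK^T / sqrt(d) + causal_mask) @ V
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
causal_mask = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
manual = torch.softmax(scores + causal_mask, dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-6))  # expected: True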

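Phi3ModelPatcher binds _phi3_self_attn_sdpa_forward onto each self_attn module on __enter__ and restores the original forward on __exit__. Below is a hypothetical, self-contained sketch of that bind-and-restore pattern; TinyAttention, patched_forward and PatchForward are illustrative names only, not optimum-intel APIs.

# Hypothetical sketch of the patch/restore pattern used by Phi3ModelPatcher.
import types
import torch


class TinyAttention(torch.nn.Module):
    def forward(self, x):
        return x  # placeholder "original" forward


def patched_forward(self, x):
    # stands in for _phi3_self_attn_sdpa_forward
    return x * 2


class PatchForward:
    def __init__(self, module):
        self.module = module

    def __enter__(self):
        # keep the original bound method, then rebind the replacement onto the instance
        self.module._orig_forward = self.module.forward
        self.module.forward = types.MethodType(patched_forward, self.module)
        return self.module

    def __exit__(self, exc_type, exc_value, traceback):
        # restore the original forward, as Phi3ModelPatcher.__exit__ does
        if hasattr(self.module, "_orig_forward"):
            self.module.forward = self.module._orig_forward


attn = TinyAttention()
x = torch.ones(2)
with PatchForward(attn):
    print(attn.forward(x))  # tensor([2., 2.]) -> patched path
print(attn.forward(x))      # tensor([1., 1.]) -> original restored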