
Commit 47af979

add sdpa for phi3 openvino model
1 parent e6fadb1 commit 47af979

1 file changed: +89 -1 lines changed

optimum/exporters/openvino/model_patcher.py (+89 -1)
@@ -951,15 +951,103 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.attention.forward = block.attention._orig_forward
 
 
+# Adapted from Phi3Attention.forward
+def _phi3_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
-
         # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
         # init inv_freq for torchscript tracing
         for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
             if layer.self_attn.rotary_emb.inv_freq is None:
                 rotary_emb = layer.self_attn.rotary_emb
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
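
Note (not part of the diff): the fused torch.nn.functional.scaled_dot_product_attention call above is a drop-in replacement for the eager softmax(QK^T / sqrt(d))V attention it patches out, which is why the patcher can swap forward methods without changing model outputs. Below is a minimal, self-contained sketch checking that equivalence on random tensors; the tensor layout and the additive float mask mirror what the patched forward passes in, and the toy sizes are arbitrary.

# Standalone sanity check (not from the commit): SDPA vs. the eager
# softmax(Q @ K^T / sqrt(d) + mask) @ V computation it replaces.
# Tensor layout matches the patched forward: (batch, num_heads, seq_len, head_dim).
import math

import torch

torch.manual_seed(0)
bsz, num_heads, q_len, head_dim = 2, 4, 8, 16  # arbitrary toy sizes
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# Additive float mask broadcastable to (bsz, 1, q_len, kv_seq_len), the shape
# the patched forward validates; all-zeros means nothing is masked out.
attention_mask = torch.zeros(bsz, 1, q_len, q_len)

# Eager reference path.
scores = query @ key.transpose(-1, -2) / math.sqrt(head_dim) + attention_mask
eager_out = torch.softmax(scores, dim=-1) @ value

# Fused path used by _phi3_self_attn_sdpa_forward.
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0
)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True (up to float rounding)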

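For context (also not part of the diff), Phi3ModelPatcher is applied during the OpenVINO export path rather than called directly. The snippet below is one hedged way to exercise it, assuming optimum-intel's OVModelForCausalLM API and using microsoft/Phi-3-mini-4k-instruct purely as an example checkpoint.

# Sketch only: assumes optimum-intel is installed and that exporting a Phi-3
# checkpoint through OVModelForCausalLM(export=True) routes the conversion
# through Phi3ModelPatcher. The model id is an illustrative example, not one
# referenced by the commit.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "microsoft/Phi-3-mini-4k-instruct"  # example checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code may be required depending on the installed transformers version.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("OpenVINO SDPA check:", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))

An ahead-of-time conversion via the optimum-cli export openvino command should take the same patching path. Note that on torch versions below 2.1.0 the is_torch_version guard in the diff leaves the original eager attention forward untouched.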