Commit d854863

fix accuracy

1 parent c674492

File tree: 1 file changed (+47 -5)

optimum/exporters/openvino/model_patcher.py (+47 -5)
@@ -722,7 +722,18 @@ def _mpt_attention_forward(
     else:
         past_key_value = (key_states, value_states)

-    attention_mask_sdpa = torch.ones(attention_mask.shape, dtype=query_states.dtype)
+    key_length = key_states.shape[-2]
+    query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
+    attention_mask_sdpa = torch.ones(
+        (query_states.shape[0], query_states.shape[1], query_states.shape[2], key_states.shape[2]),
+        dtype=query_states.dtype,
+    )
+    if position_bias is not None:
+        position_bias_query_index = max(0, position_bias.size(1) - query_length)
+        position_bias_key_index = max(0, position_bias.size(2) - key_length)
+
+        position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
+        attention_mask_sdpa += position_bias
     attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min)
     context_states = torch.nn.functional.scaled_dot_product_attention(
         query_states,
@@ -732,6 +743,7 @@ def _mpt_attention_forward(
         dropout_p=self.attn_dropout_p,
         scale=self.softmax_scale,
     )
+
     context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
     attn_output = self.out_proj(context_states)

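For orientation, the MPT change above builds the additive SDPA mask at the full (query_length, key_length) size and folds the sliced ALiBi position bias into it before masking out disallowed positions. The snippet below is a minimal standalone sketch of that pattern, not part of the commit; the tensor sizes, the pad_mask name, and the fixed bias table are illustrative assumptions.

# Minimal standalone sketch (not part of the commit): combining an ALiBi-style
# position bias with a boolean mask into one additive SDPA mask.
# Shapes and names here are illustrative assumptions, not the MPT model's API.
import torch

batch, n_heads, q_len, k_len = 1, 4, 3, 5

query = torch.randn(batch, n_heads, q_len, 64)
key = torch.randn(batch, n_heads, k_len, 64)
value = torch.randn(batch, n_heads, k_len, 64)

# Bias built over a longer (max) sequence; take the trailing window that matches
# the current query/key lengths, mirroring the indexing added in the patch.
position_bias = torch.randn(n_heads, 8, 8)
bias = position_bias[:, -q_len:, -k_len:]

# Boolean mask: True marks positions that must not be attended to.
pad_mask = torch.zeros(batch, 1, q_len, k_len, dtype=torch.bool)
pad_mask[..., -1] = True  # e.g. treat the last key position as padding

attn_mask = torch.ones(batch, n_heads, q_len, k_len, dtype=query.dtype)
attn_mask += bias  # broadcast over the batch dimension
attn_mask.masked_fill_(pad_mask, torch.finfo(query.dtype).min)

out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 4, 3, 64])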
@@ -764,17 +776,47 @@ def _internlm_attention_forward(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+    # from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+    from einops import rearrange
+
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors."""
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+        """
+        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+        """
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1:
+            return hidden_states
+        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

     bsz, q_len, _ = hidden_states.size()

     qkv_states = self.wqkv(hidden_states)

-    qkv_states = qkv_states.reshape(
-        qkv_states.shape[0], qkv_states.shape[1], -1, 2 + self.num_key_values_groups, self.head_dim
+    qkv_states = rearrange(
+        qkv_states,
+        "b q (h gs d) -> b q h gs d",
+        gs=2 + self.num_key_value_groups,
+        d=self.head_dim,
     )
+
     query_states = qkv_states[..., : self.num_key_value_groups, :]
-    query_states = query_states.reshape(query_states.shape[0], query_states.shape[1], -1, query_states.shape[-1])
+    query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
     key_states = qkv_states[..., -2, :]
     value_states = qkv_states[..., -1, :]

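As a side note, the patched InternLM forward above splits the packed wqkv projection with einops.rearrange instead of a plain reshape, and inlines rotate_half / apply_rotary_pos_emb / repeat_kv rather than importing them from modeling_llama. Below is a minimal standalone sketch of that rearrange split on a dummy tensor, not part of the commit; the head counts and dimensions are illustrative assumptions.

# Minimal standalone sketch (not part of the commit): splitting a packed
# QKV projection with einops.rearrange, as in the patched forward.
# All sizes below are illustrative assumptions.
import torch
from einops import rearrange

batch, seq_len = 2, 7
num_heads, num_key_value_heads, head_dim = 8, 2, 16
num_key_value_groups = num_heads // num_key_value_heads  # 4 query heads per kv head

# Packed projection: each kv head carries its group of query heads plus one key and one value.
packed = torch.randn(batch, seq_len, num_key_value_heads * (2 + num_key_value_groups) * head_dim)

qkv = rearrange(packed, "b q (h gs d) -> b q h gs d", gs=2 + num_key_value_groups, d=head_dim)

query = rearrange(qkv[..., :num_key_value_groups, :], "b q h gs d -> b q (h gs) d")
key = qkv[..., -2, :]
value = qkv[..., -1, :]

print(query.shape)  # torch.Size([2, 7, 8, 16])
print(key.shape)    # torch.Size([2, 7, 2, 16])
print(value.shape)  # torch.Size([2, 7, 2, 16])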