
Commit a255a08

committed: baichuan sdpa
1 parent 37f2094 commit a255a08

File tree

1 file changed: +71 −1 lines changed


optimum/exporters/openvino/model_patcher.py

+71 −1
@@ -673,6 +673,71 @@ def _baichuan13b_atten_forward(
     return attn_output, attn_weights, past_key_value
 
 
+def _baichuan7b_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    bsz, q_len, _ = hidden_states.size()
+
+    proj = self.W_pack(hidden_states)
+    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+    if not output_attentions:
+        attn_weights = None
+        attn_output = F.scaled_dot_product_attention(
+            query_states, key_states, value_states, attn_mask=attention_mask, scale=1 / math.sqrt(self.head_dim)
+        )
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
 class BaichuanModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
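
Note: the packed-QKV split added above, proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2), is just a reshaped view that makes proj[0], proj[1], proj[2] the query, key and value projections. A minimal standalone sketch (hypothetical sizes, plain PyTorch, no Baichuan modules involved) checking that it matches a plain chunk along the last dimension:

import torch

# Hypothetical sizes for illustration only
bsz, q_len, hidden_size = 2, 5, 32

# Stand-in for self.W_pack(hidden_states): a packed [Q | K | V] projection output
proj = torch.randn(bsz, q_len, 3 * hidden_size)

# The trick used in _baichuan7b_attn_forward: index 0/1/2 select Q/K/V
packed = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)

# The straightforward alternative: split the last dimension into three chunks
q, k, v = proj.chunk(3, dim=-1)

assert torch.equal(packed[0], q)
assert torch.equal(packed[1], k)
assert torch.equal(packed[2], v)
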
@@ -720,13 +785,18 @@ def forward(
             for layer in self._model.model.layers:
                 layer.self_attn._orig_forward = layer.self_attn.forward
                 layer.self_attn.forward = types.MethodType(_baichuan13b_atten_forward, layer.self_attn)
+        else:
+            for layer in self._model.model.layers:
+                layer.self_attn._orig_forward = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_baichuan7b_attn_forward, layer.self_attn)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model, "_orig_forward"):
             self._model.forward = self._model._orig_forward
 
-            for layer in self._model.model.layers:
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
 
 
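
For context, the two branches of _baichuan7b_attn_forward are meant to produce the same attention output: the output_attentions=False path calls F.scaled_dot_product_attention with an explicit scale=1 / math.sqrt(self.head_dim), which matches the manual matmul/softmax path kept in the else branch. A small self-contained sketch illustrating that equivalence (hypothetical shapes and an additive causal mask; assumes a PyTorch build where scaled_dot_product_attention accepts the scale argument, as the patched code itself does):

import math
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration: [bsz, num_heads, seq_len, head_dim]
bsz, num_heads, q_len, kv_len, head_dim = 1, 4, 6, 6, 16
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, kv_len, head_dim)
value = torch.randn(bsz, num_heads, kv_len, head_dim)

# Additive causal mask, same convention as the manual branch (0 where allowed, -inf where masked)
mask = torch.full((q_len, kv_len), float("-inf")).triu(1)[None, None]

# SDPA branch (output_attentions=False in the patched forward)
sdpa_out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=mask, scale=1 / math.sqrt(head_dim)
)

# Manual branch (output_attentions=True in the patched forward)
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
manual_out = torch.matmul(attn_weights, value)

print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # expected: True
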

0 commit comments
