Commit e60872c

Commit message: add support output_attentions
1 parent: 5aa30ed

File tree: 1 file changed

optimum/exporters/openvino/model_patcher.py (+83 −10 lines)

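Note on the change: the diff below makes each patched attention forward honor `output_attentions`. The fused SDPA path is kept when attention weights are not requested, and an explicit matmul-softmax path is used when they are, because `torch.nn.functional.scaled_dot_product_attention` does not return attention probabilities. For context, here is a minimal sketch of how a caller requests attention maps through the plain Transformers API ("gpt2" is only a placeholder model id; this snippet is not specific to this patch or to the OpenVINO export flow):

# Hedged usage sketch: a standard Transformers call with output_attentions=True.
# "gpt2" is a placeholder model id, unrelated to the models patched in this file.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("hello world", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True)

# One attention tensor per layer, shaped [batch, heads, seq_len, seq_len].
print(len(out.attentions), out.attentions[0].shape)
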
@@ -640,11 +640,25 @@ def _baichuan13b_atten_forward(
         attention_mask = attention_mask[:, :, -key_states.shape[-2] :, :]
         key_states = torch.cat([past_key_value[0], key_states], dim=2)
         value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    if not output_attentions:
+        past_key_value = (key_states, value_states) if use_cache else None
+        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            if q_len == 1:  # inference with cache
+                if len(attention_mask.size()) == 4:
+                    attention_mask = attention_mask[:, :, -1:, :]
+                else:
+                    attention_mask = attention_mask[:, -1:, :]
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights, value_states)
 
-    past_key_value = (key_states, value_states) if use_cache else None
-    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
     attn_output = attn_output.transpose(1, 2)
-    attn_weights = None
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
 
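The Baichuan hunk above establishes the pattern the rest of the commit repeats: dispatch on `output_attentions`, use SDPA when the weights are not needed, and fall back to an explicit softmax attention when they are. Below is a minimal standalone sketch of that dispatch, with illustrative names and shapes rather than the patched method itself:

# Sketch of the SDPA-vs-eager dispatch used in the patched forwards.
# Function and variable names here are illustrative, not optimum-intel API.
import math
import torch
import torch.nn.functional as F

def attention(query, key, value, attn_mask=None, output_attentions=False):
    # query/key/value: [batch, heads, seq_len, head_dim]
    if not output_attentions:
        # Fast path: fused kernel, attention probabilities are never materialized.
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask), None
    # Eager path: compute the weights explicitly so they can be returned.
    weights = query @ key.transpose(2, 3) / math.sqrt(query.shape[-1])
    if attn_mask is not None:
        weights = weights + attn_mask
    weights = torch.softmax(weights, dim=-1)
    return weights @ value, weights

q = k = v = torch.randn(1, 2, 4, 8)
out, w = attention(q, k, v, output_attentions=True)
print(out.shape, w.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 4, 4])
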
@@ -708,7 +722,7 @@ def __exit__(self, exc_type, exc_value, traceback):
                 layer.self_attn.forward = layer.self_attn._orig_forward
 
 
-def _mpt_attention_forward(
+def _mpt_sdpa_attention_forward(
     self,
     hidden_states: torch.Tensor,
     position_bias: torch.Tensor,

@@ -759,18 +773,73 @@ def _mpt_attention_forward(
     return attn_output, None, past_key_value
 
 
+def _mpt_block_forward(
+    self,
+    hidden_states: torch.Tensor,
+    position_bias: torch.Tensor,
+    attention_mask: torch.Tensor,
+    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    use_cache: bool = False,
+    output_attentions: bool = False,
+):
+    # hidden_states: [batch_size, seq_length, hidden_size]
+    # Layer norm at the beginning of the transformer layer.
+    layernorm_output = self.norm_1(hidden_states)
+
+    residual = hidden_states
+
+    if not output_attentions:
+        # Self attention.
+        attn_outputs, attn_weights, past_key_value = self.attn(
+            layernorm_output,
+            position_bias=position_bias,
+            attention_mask=attention_mask,
+            past_key_value=layer_past,
+        )
+    else:
+        attn_outputs, attn_weights, past_key_value = self.attn._orig_forward(
+            layernorm_output,
+            position_bias=position_bias,
+            attention_mask=attention_mask,
+            past_key_value=layer_past,
+        )
+
+    hidden_states = self.resid_attn_dropout(attn_outputs) + residual
+
+    layernorm_output = self.norm_2(hidden_states)
+
+    # Get residual
+    residual = hidden_states
+
+    # MLP.
+    output = self.ffn(layernorm_output, residual)
+    outputs = (output,)
+
+    if use_cache:
+        outputs += (past_key_value,)
+
+    if output_attentions:
+        outputs += (attn_weights,)
+
+    return outputs
+
+
 class MPTModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
 
         if is_torch_version(">=", "2.1.0"):
             for block in self._model.transformer.blocks:
+                block._orig_forward = block.forward
+                block.forward = types.MethodType(_mpt_block_forward, block)
                 block.attn._orig_forward = block.attn.forward
-                block.attn.forward = types.MethodType(_mpt_attention_forward, block.attn)
+                block.attn.forward = types.MethodType(_mpt_sdpa_attention_forward, block.attn)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for block in self._model.transformer.blocks:
+            if hasattr(block, "_orig_forward"):
+                block.forward = block._orig_forward
             if hasattr(block.attn, "_orig_forward"):
                 block.attn.forward = block.attn._orig_forward
 
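The MPT patcher now swaps the forward of every block in addition to the forward of every attention module, and always stashes the bound original in `_orig_forward` so that `__exit__` can restore it. Here is a toy illustration of that `types.MethodType` patch-and-restore pattern (the `Block` class below is a stand-in, not optimum code):

# Toy illustration of the patch/restore pattern used by MPTModelPatcher:
# bind a replacement forward with types.MethodType and keep the original
# bound method around so it can be restored on exit.
import types
import torch

class Block(torch.nn.Module):
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    # Replacement bound as an instance method; it can still call the original.
    return self._orig_forward(x) * 2

block = Block()
block._orig_forward = block.forward                        # save the bound original
block.forward = types.MethodType(patched_forward, block)   # install the patch
print(block.forward(torch.tensor(1.0)))                    # tensor(4.) -> (1 + 1) * 2

if hasattr(block, "_orig_forward"):
    block.forward = block._orig_forward                    # restore on exit
print(block.forward(torch.tensor(1.0)))                    # tensor(2.)
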
@@ -848,17 +917,21 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
+    if not output_attentions:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+        )
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
-    )
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
     attn_output = self.wo(attn_output)
 
-    attn_weights = None
-
     return attn_output, attn_weights, past_key_value
 
 
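The final hunk applies the same dispatch inside a grouped-query attention forward, where `repeat_kv` (visible in the hunk header) has already expanded the key/value heads to match the query heads. For reference, a sketch of the Llama-style `repeat_kv` helper as commonly defined in transformers-derived modeling code; the shapes below are illustrative:

# Sketch of a Llama-style repeat_kv helper: expand num_key_value_heads to
# num_attention_heads so grouped-query K/V line up with every query head.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 2, 5, 8)    # 2 KV heads
print(repeat_kv(kv, 4).shape)   # torch.Size([1, 8, 5, 8]) -> matches 8 query heads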