Commit 514f054

apply sdpa for mpt and internlm
1 parent 673b88b commit 514f054

2 files changed: +151 -4 lines changed
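Both patches follow the same idea: when running on PyTorch >= 2.1, the models' eager attention forward is temporarily replaced with one that calls torch.nn.functional.scaled_dot_product_attention, and the original forward is restored after export. As a quick reference for what that substitution means numerically, here is a minimal sketch (shapes and tolerance are illustrative, not taken from this commit):

    import math
    import torch

    # (batch, heads, seq_len, head_dim) -- illustrative sizes only
    q = torch.randn(1, 4, 6, 16)
    k = torch.randn(1, 4, 6, 16)
    v = torch.randn(1, 4, 6, 16)

    # Eager attention: explicit matmul -> softmax -> matmul
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    eager = torch.softmax(scores, dim=-1) @ v

    # Fused form that the patched forwards call into
    sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v)

    print(torch.allclose(eager, sdpa, atol=1e-5))  # True, up to numerical tolerance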

optimum/exporters/openvino/model_configs.py (+18 -1)

@@ -19,7 +19,7 @@
 from transformers.utils import is_tf_available

 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig
+from optimum.exporters.onnx.model_configs import GemmaOnnxConfig, LlamaOnnxConfig, MPTOnnxConfig
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
@@ -37,6 +37,8 @@
     LlamaModelPatcher,
     MixtralModelPatcher,
     QwenModelPatcher,
+    MPTModelPatcher,
+    InternLMPatcher,
 )


@@ -429,6 +431,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return InternLMPatcher(self, model, model_kwargs=model_kwargs)
+

 @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -437,3 +444,13 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager(
+    "mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers"
+)
+class MPTOpenVINOConfig(MPTOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MPTModelPatcher(self, model, model_kwargs=model_kwargs)
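With this registration, the "mpt" model type resolves to MPTOpenVINOConfig through TasksManager, and patch_model_for_export() returns the patcher that stays active while the model is traced. A rough usage sketch, assuming optimum-intel is installed (the checkpoint name is only an example, not part of this commit):

    # Rough sketch; "mosaicml/mpt-7b" is an example checkpoint.
    from optimum.intel import OVModelForCausalLM

    # export=True converts the PyTorch checkpoint to OpenVINO IR; the patcher
    # returned by patch_model_for_export() is expected to be applied during tracing.
    ov_model = OVModelForCausalLM.from_pretrained("mosaicml/mpt-7b", export=True)
    ov_model.save_pretrained("mpt-7b-openvino")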

optimum/exporters/openvino/model_patcher.py (+133 -3)

@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging as log
+import math
 import types
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

@@ -327,9 +328,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
             offset = 0
         mask_shape = attention_mask.shape
         mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-        causal_mask[
-            : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-        ] = mask_slice
+        causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+            mask_slice
+        )

     if (
         self.config._attn_implementation == "sdpa"
@@ -611,3 +612,132 @@ def __init__(
         # model has first inference buffers initialization
         if hasattr(self._model.lm_head, "first_flag"):
             self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))
+
+
+def _mpt_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    position_bias: torch.Tensor,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+):
+    batch_size, seq_length = hidden_states.shape[:2]
+
+    mixed_qkv = self.Wqkv(hidden_states)
+    query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
+    query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        past_key_value = (key_states, value_states)
+    else:
+        past_key_value = (key_states, value_states)
+
+    attention_mask_sdpa = torch.ones(attention_mask.shape, dtype=query_states.dtype)
+    attention_mask_sdpa.masked_fill_(attention_mask, torch.finfo(query_states.dtype).min)
+    context_states = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask_sdpa,
+        dropout_p=self.attn_dropout_p,
+        scale=self.softmax_scale,
+    )
+    context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
+    attn_output = self.out_proj(context_states)
+
+    return attn_output, None, past_key_value
+
+
+class MPTModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        if is_torch_version(">=", "2.1.0"):
+            for block in self._model.transformer.blocks:
+                block.attn._orig_forward = block.attn.forward
+                block.attn.forward = types.MethodType(_mpt_attention_forward, block.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.transformer.blocks:
+            if hasattr(block.attn, "_orig_forward"):
+                block.attn.forward = block.attn._orig_forward
+
+
+def _internlm_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+    from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv_states = self.wqkv(hidden_states)
+
+    qkv_states = qkv_states.reshape(
+        qkv_states.shape[0], qkv_states.shape[1], -1, 2 + self.num_key_value_groups, self.head_dim
+    )
+    query_states = qkv_states[..., : self.num_key_value_groups, :]
+    query_states = query_states.reshape(query_states.shape[0], query_states.shape[1], -1, query_states.shape[-1])
+    key_states = qkv_states[..., -2, :]
+    value_states = qkv_states[..., -1, :]
+
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.wo(attn_output)
+
+    attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+class InternLMPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        if is_torch_version(">=", "2.1.0"):
+            for block in self._model.model.layers:
+                block.attention._orig_forward = block.attention.forward
+                block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            if hasattr(block.attention, "_orig_forward"):
+                block.attention.forward = block.attention._orig_forward
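One detail worth noting in the MPT path: scaled_dot_product_attention receives a floating-point additive mask, so the boolean mask the model normally works with (True = position masked out) is first converted with masked_fill_. A standalone sketch of that conversion, with made-up shapes that are not part of the commit:

    import torch

    # Illustrative sizes: batch=1, heads=1, query_len=2, key_len=4
    bool_mask = torch.tensor([[[[False, False, True, True],
                                [False, False, False, True]]]])  # True = masked out

    query = torch.randn(1, 1, 2, 8)
    key = torch.randn(1, 1, 4, 8)
    value = torch.randn(1, 1, 4, 8)

    # Same recipe as _mpt_attention_forward: start from ones and write the most
    # negative representable value wherever the boolean mask is True. Unmasked
    # positions keep a finite value; masked positions get a value so negative
    # that softmax drives their weight to zero.
    additive_mask = torch.ones(bool_mask.shape, dtype=query.dtype)
    additive_mask.masked_fill_(bool_mask, torch.finfo(query.dtype).min)

    out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=additive_mask)
    print(out.shape)  # torch.Size([1, 1, 2, 8])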
