
Commit 3db832a

support export of more models
1 parent 8c2b787 commit 3db832a

4 files changed (+330, -5 lines)

optimum/exporters/openvino/model_configs.py (+47, -2)
@@ -41,15 +41,18 @@
 from optimum.utils.normalized_config import NormalizedTextConfig

 from .model_patcher import (
+    AquilaModelPatcher,
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     GemmaModelPatcher,
-    InternLMPatcher,
+    InternLM2Patcher,
+    InternLMModelPatcher,
     LlamaModelPatcher,
     MixtralModelPatcher,
     MPTModelPatcher,
     Phi3ModelPatcher,
     QwenModelPatcher,
+    XverseModelPatcher,
 )


@@ -445,7 +448,7 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
-        return InternLMPatcher(self, model, model_kwargs=model_kwargs)
+        return InternLM2Patcher(self, model, model_kwargs=model_kwargs)


 @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
@@ -653,3 +656,45 @@ class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_attention_heads="attention_heads", hidden_size="d_model"
     )
+
+
+@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return AquilaModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("xverse", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class XverseMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return XverseModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("internlm", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class InternLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)
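Note on usage (not part of the diff): once these configs are registered in the tasks manager, the new architectures can be exported through the usual optimum-intel entry points. A minimal sketch under assumptions: the model ID below is only an illustrative example, and trust_remote_code=True is assumed to be required because these architectures ship custom modeling code.

# Hypothetical usage sketch; the checkpoint ID is an assumption, not taken from this commit.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "BAAI/AquilaChat2-7B"  # assumed example checkpoint for the "aquila" architecture
# export=True converts the PyTorch checkpoint to OpenVINO IR using the config registered above
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The weather today is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The same conversion can also be run from the command line with optimum-cli export openvino --model <model_id> --task text-generation-with-past <output_dir>, which routes through the same registered configs.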

optimum/exporters/openvino/model_patcher.py (+271, -3)
@@ -844,7 +844,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.attn.forward = block.attn._orig_forward


-def _internlm_attention_forward(
+def _internlm2_attention_forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -935,14 +935,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return attn_output, attn_weights, past_key_value


-class InternLMPatcher(DecoderModelPatcher):
+class InternLM2Patcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()

         if is_torch_version(">=", "2.1.0"):
             for block in self._model.model.layers:
                 block.attention._orig_forward = block.attention.forward
-                block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention)
+                block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention)

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -1055,3 +1055,271 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+def _aquila_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+        """
+        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+        """
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1:
+            return hidden_states
+        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
+        )
+    bsz, q_len, _ = hidden_states.size()
+
+    if self.config.pretraining_tp > 1:
+        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_weights = None
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+class AquilaModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+def _xverse_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
+        )
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_weights = None
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+def _internlm_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        cos = cos[position_ids].unsqueeze(1)
+        sin = sin[position_ids].unsqueeze(1)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_weights = None
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+    return attn_output, attn_weights, past_key_value
+
+
+class XverseModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+class InternLMModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
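The three new patchers all follow the pattern already used elsewhere in this file: on __enter__, each attention module's forward is swapped for an SDPA-based replacement via types.MethodType (only when torch >= 2.1.0, presumably because scaled_dot_product_attention gained the scale argument in that release), the original is kept in _orig_forward, and __exit__ restores it. A stripped-down, self-contained sketch of that pattern follows; the toy classes are illustrative assumptions, not the real DecoderModelPatcher or transformers modules.

# Illustrative sketch of the patching pattern above, using toy classes; the actual
# patchers also run DecoderModelPatcher's own __enter__/__exit__ bookkeeping.
import types


class ToyAttention:
    def forward(self, hidden_states):
        return hidden_states  # stands in for the eager attention implementation


def _toy_sdpa_forward(self, hidden_states):
    # stand-in for an SDPA-based forward such as _aquila_self_attn_sdpa_forward
    return hidden_states


class ToyPatcher:
    def __init__(self, layers):
        self._layers = layers

    def __enter__(self):
        for layer in self._layers:
            orig_fwd = layer.forward
            # bind the replacement to this instance and remember the original
            layer.forward = types.MethodType(_toy_sdpa_forward, layer)
            layer._orig_forward = orig_fwd
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for layer in self._layers:
            if hasattr(layer, "_orig_forward"):
                layer.forward = layer._orig_forward


layers = [ToyAttention() for _ in range(2)]
with ToyPatcher(layers):
    pass  # inside the context the patched forwards are active; export/tracing would run here
# on exit, the original forwards are restored

This mirrors how AquilaModelPatcher, XverseModelPatcher, and InternLMModelPatcher wrap the model only for the duration of the export.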
