@@ -844,7 +844,7 @@ def __exit__(self, exc_type, exc_value, traceback):
            block.attn.forward = block.attn._orig_forward


-def _internlm_attention_forward(
+def _internlm2_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
@@ -935,14 +935,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    return attn_output, attn_weights, past_key_value


-class InternLMPatcher(DecoderModelPatcher):
+class InternLM2Patcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()

        if is_torch_version(">=", "2.1.0"):
            for block in self._model.model.layers:
                block.attention._orig_forward = block.attention.forward
-                block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention)
+                block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
@@ -1055,3 +1055,271 @@ def __exit__(self, exc_type, exc_value, traceback):
        for layer in self._model.model.layers:
            if hasattr(layer.self_attn, "_orig_forward"):
                layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+def _aquila_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+        """
+        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+        """
+        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+        if n_rep == 1:
+            return hidden_states
+        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
+        )
+    bsz, q_len, _ = hidden_states.size()
+
+    if self.config.pretraining_tp > 1:
+        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_weights = None
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+class AquilaModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+def _xverse_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed

+    if output_attentions:
+        return self._orig_forward(
+            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
+        )
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_weights = None
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+def _internlm_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+        cos = cos[position_ids].unsqueeze(1)
+        sin = sin[position_ids].unsqueeze(1)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
+    )
+    attn_weights = None
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+    return attn_output, attn_weights, past_key_value
+
+
+class XverseModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+class InternLMModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
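All of the patched forwards above converge on the same torch.nn.functional.scaled_dot_product_attention call, passing the model's additive attention mask and an explicit 1/sqrt(head_dim) scale; the explicit scale keyword needs a recent PyTorch, which is presumably why every patcher guards on is_torch_version(">=", "2.1.0"). Below is a minimal, self-contained sketch of that call pattern. It is illustrative only and not part of this diff; the tensor shapes and the causal mask are invented for the example.

# Illustrative sketch (not from the diff): SDPA with an additive mask and an
# explicit 1/sqrt(head_dim) scale, as used by the patched attention forwards.
# Requires PyTorch >= 2.1 for the `scale` keyword argument.
import math
import torch

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# Additive causal mask: 0.0 where attention is allowed, -inf above the diagonal.
attention_mask = torch.full((q_len, q_len), float("-inf")).triu(1)[None, None, :, :]

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attention_mask, scale=(1 / math.sqrt(head_dim))
)
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])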