@@ -229,7 +229,7 @@ def _llama_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
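This same one-line change, dropping `.item()` so that `max_input_lens` stays a 0-dim tensor rather than a Python int, is repeated below for the falcon, gpt2 and qwen2 forwards. The diff itself does not state the motivation, but a plausible reading is that `.item()` forces a device-to-host copy and is a common graph-break point under `torch.compile`, while a scalar tensor can still be used wherever an integer size is accepted. A minimal sketch of the difference, with made-up inputs:

import torch

# Hypothetical padded batch: two sequences of lengths 3 and 4.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)  # tensor([3, 4], dtype=torch.int32)

max_as_int = input_lens.max().item()  # Python int 4: eager host sync, graph break under torch.compile
max_as_tensor = input_lens.max()      # 0-dim tensor(4): stays a tensor until something needs the int

# A 0-dim integer tensor is accepted where an int size is expected (it implements __index__),
# which is how the attention_interface hunk further down uses it inside reshape().
x = torch.arange(8)
print(x.reshape(2, max_as_tensor).shape)  # torch.Size([2, 4])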
@@ -357,7 +357,7 @@ def _falcon_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -499,7 +499,7 @@ def _gpt2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -635,7 +635,7 @@ def _qwen2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -752,11 +752,11 @@ def attention_interface(
     if past_key_value is None:
         n_rep = query.shape[1] // key.shape[1]
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-            key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+            query.reshape(input_lens.shape[0], input_lens.max(), -1, query.shape[-1]).transpose(1, 2),
+            key.reshape(input_lens.shape[0], input_lens.max(), -1, key.shape[-1])
             .transpose(1, 2)
             .repeat_interleave(n_rep, 1),
-            value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+            value.reshape(input_lens.shape[0], input_lens.max(), -1, value.shape[-1])
             .transpose(1, 2)
             .repeat_interleave(n_rep, 1),
             attn_mask=attention_mask,
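For context, the unchanged lines around this hunk implement grouped-query attention for the prefill path: each key/value head is repeated `n_rep = query_heads // kv_heads` times along the head dimension so that `scaled_dot_product_attention` sees matching head counts. A simplified, self-contained sketch of that expansion (shapes are illustrative, not taken from the diff):

import torch

# Illustrative GQA shapes: 8 query heads sharing 2 key/value heads.
bs, q_heads, kv_heads, seq_len, head_dim = 2, 8, 2, 5, 16
query = torch.randn(bs, q_heads, seq_len, head_dim)
key = torch.randn(bs, kv_heads, seq_len, head_dim)
value = torch.randn(bs, kv_heads, seq_len, head_dim)

n_rep = query.shape[1] // key.shape[1]        # 4 query heads per KV head
key = key.repeat_interleave(n_rep, dim=1)     # -> (bs, q_heads, seq_len, head_dim)
value = value.repeat_interleave(n_rep, dim=1)

attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True)
print(attn_output.shape)  # torch.Size([2, 8, 5, 16])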
@@ -883,13 +883,11 @@ def __init__(self, module, device, config) -> None:
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
-        if self.module_device.type == "cpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+        if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+            if self.module_device.type == "cpu":
                 self.mha_linear_add = LinearAdd(module.o_proj)
-
             elif self.module_device.type == "xpu":
-                if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                    self.mha_linear_add = XPULinearAdd(module.o_proj)
+                self.mha_linear_add = XPULinearAdd(module.o_proj)

     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
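The `__init__` hunks from here on all apply the same gate: the fused modules used for eager-mode speedups (LinearAdd, XPULinearAdd, LinearGelu, LinearNewGelu and the transposed `c_attn`/`c_fc` linears) are only created when `config.compile` is false, so a model intended for `torch.compile` keeps the stock PyTorch submodules and leaves fusion to the compiler. A rough sketch of the pattern, using a stand-in for the fused op and a made-up config object (the real `config.compile` flag and the IPEX fused ops are not reproduced here):

import torch
from torch import nn
from types import SimpleNamespace


class LinearAddSketch(nn.Module):
    """Stand-in for a fused linear+add op such as LinearAdd: out = linear(x) + residual."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x, residual):
        return self.linear(x) + residual


config = SimpleNamespace(compile=False)  # hypothetical stand-in for the model config's compile flag
o_proj = nn.Linear(16, 16)
attn = nn.Module()

# Mirror of the gating in the diff: only build the fused path in eager mode,
# and never for tensor-parallel LinearAllreduce projections.
if not config.compile and o_proj.__class__.__name__ not in ["LinearAllreduce"]:
    attn.mha_linear_add = LinearAddSketch(o_proj)

hidden, residual = torch.randn(2, 16), torch.randn(2, 16)
if hasattr(attn, "mha_linear_add"):
    out = attn.mha_linear_add(hidden, residual)  # fused eager path
else:
    out = o_proj(hidden) + residual              # plain path, left for torch.compile to optimize
print(out.shape)  # torch.Size([2, 16])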
@@ -932,7 +930,7 @@ def __init__(self, module, device, config) -> None:
         self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
             self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
             self.c_attn_linear.bias = self.c_attn.bias
@@ -976,7 +974,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
                 if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
@@ -1009,7 +1007,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense_h_to_4h)
@@ -1049,7 +1047,7 @@ def __init__(self, module, device, config) -> None:
         self.config = config
         self.module_device = device

-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
             self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
             self.c_fc_linear.bias = self.c_fc.bias
@@ -1058,11 +1056,8 @@ def __init__(self, module, device, config) -> None:
             self.c_proj_linear.bias = self.c_proj.bias
             if self.module_device.type == "cpu":
                 self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-
-            if self.module_device.type == "cpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = LinearAdd(self.c_proj_linear)
-
             elif self.module_device.type == "xpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = XPULinearAdd(self.c_proj_linear)
@@ -1234,7 +1229,7 @@ def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
             elif self.module_device.type == "xpu":