@@ -1099,7 +1099,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
         )
     bsz, q_len, _ = hidden_states.size()
 
-    if self.config.pretraining_tp > 1:
+    if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
         key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
         query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
         key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
@@ -1120,8 +1120,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
         value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim
+    ).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
@@ -1136,9 +1140,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    if hasattr(self, "num_key_value_groups"):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
 
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
@@ -1148,7 +1153,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
-    if self.config.pretraining_tp > 1:
+    if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
         attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
         o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
         attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
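The guarded lookups in this diff all follow one pattern: treat `pretraining_tp`, `num_key_value_heads`, and `num_key_value_groups` as optional attributes and fall back to plain multi-head behaviour when they are missing. Below is a minimal, self-contained sketch of that fallback pattern, assuming an older config/module that predates these attributes; the helper name `resolve_attention_config` and the `SimpleNamespace` stand-ins are illustrative only, not part of the patch.

```python
# Sketch of the hasattr/getattr fallback pattern; names here are hypothetical.
from types import SimpleNamespace


def resolve_attention_config(config, module, num_heads):
    # Configs without `pretraining_tp` behave as if tensor-parallel slicing is disabled.
    pretraining_tp = getattr(config, "pretraining_tp", 1)
    # Modules without `num_key_value_heads` are plain multi-head attention,
    # i.e. one key/value head per query head (no grouped-query attention).
    num_key_value_heads = getattr(module, "num_key_value_heads", num_heads)
    return pretraining_tp, num_key_value_heads


# Usage: a config/module lacking both attributes still resolves to safe defaults.
old_config = SimpleNamespace(hidden_size=4096)
old_module = SimpleNamespace()
print(resolve_attention_config(old_config, old_module, num_heads=32))  # -> (1, 32)
```

With these defaults, the tensor-parallel slicing branch and the `repeat_kv` expansion are simply skipped for models that never defined the corresponding attributes, instead of raising `AttributeError`.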