@@ -167,24 +167,27 @@ def __init__(
         )
         self.multi_query_group_num = normalized_config.multi_query_group_num
         self.head_dim = normalized_config.kv_channels
+        self.standart_cache_layout = hasattr(normalized_config, "rope_ratio")

     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        past_key_shape = (
-            self.sequence_length,
-            self.batch_size,
-            self.multi_query_group_num,
-            self.head_dim,
-        )
-        past_value_shape = (
-            self.sequence_length,
-            self.batch_size,
-            self.multi_query_group_num,
-            self.head_dim,
-        )
+        if not self.standart_cache_layout:
+            pkv_shape = (
+                self.sequence_length,
+                self.batch_size,
+                self.multi_query_group_num,
+                self.head_dim,
+            )
+        else:
+            pkv_shape = (
+                self.batch_size,
+                self.multi_query_group_num,
+                self.sequence_length,
+                self.head_dim,
+            )
        return [
            (
-                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
-                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(pkv_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(pkv_shape, framework=framework, dtype=float_dtype),
            )
            for _ in range(self.num_layers)
        ]
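For context on the layout switch above: ChatGLM2-style checkpoints store the KV cache as `[seq_len, batch, num_kv_heads, head_dim]`, while configs that define `rope_ratio` (ChatGLM3 / GLM-4) use the conventional `[batch, num_kv_heads, seq_len, head_dim]` layout, which is what the `hasattr` probe distinguishes. A minimal sketch of the two shapes, with illustrative dimensions rather than values from any real config:

```python
import torch

# Illustrative dimensions only; real values come from the normalized config.
batch_size, seq_len, num_kv_heads, head_dim = 2, 8, 2, 128

# ChatGLM2-style layout, used when the config has no `rope_ratio` attribute:
legacy_kv = torch.rand(seq_len, batch_size, num_kv_heads, head_dim)

# Conventional layout, used when `rope_ratio` is present (ChatGLM3 / GLM-4):
standard_kv = torch.rand(batch_size, num_kv_heads, seq_len, head_dim)

# `generate` returns one (key, value) pair of the chosen shape per layer.
```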
@@ -229,7 +232,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             and "attention_mask" in dummy_inputs
         ):
             # Obtain the past sequence length from the value instead of the key (Bloom). ChatGLM has seq_len in 0 dim instead of -2
-            past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[0]
+            seq_len_dim = 0 if not hasattr(self._normalized_config, "rope_ratio") else -2
+            past_present_length = (
+                dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[seq_len_dim]
+            )

             dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["attention_mask"],
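The attention mask must cover the prompt plus the cached tokens, so its target length is `input_ids.shape[1]` plus the cached sequence length, read from whichever axis holds it in the active layout. A small sketch of that arithmetic under assumed shapes:

```python
import torch

# Illustrative shapes; the `rope_ratio` probe decides which axis holds seq_len.
input_ids_len = 4
legacy_value = torch.rand(8, 2, 2, 128)    # [seq, batch, kv_heads, head_dim]: seq at dim 0
standard_value = torch.rand(2, 2, 8, 128)  # [batch, kv_heads, seq, head_dim]: seq at dim -2

seq_len_dim = 0  # would be -2 for a config that defines `rope_ratio`
past_len = legacy_value.shape[seq_len_dim]
past_present_length = input_ids_len + past_len  # 4 + 8 = 12
assert standard_value.shape[-2] == past_len
```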
@@ -260,9 +266,18 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
         decoder_sequence_name = "past_sequence_length + present_lenght"
         name = "present"

+        is_v4 = hasattr(self._normalized_config, "rope_ratio")
         for i in range(self._normalized_config.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
-            inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.key"] = (
+                {1: "batch_size", 0: decoder_sequence_name}
+                if not is_v4
+                else {0: "batch_size", 2: decoder_sequence_name}
+            )
+            inputs_or_outputs[f"{name}.{i}.value"] = (
+                {1: "batch_size", 0: decoder_sequence_name}
+                if not is_v4
+                else {0: "batch_size", 2: decoder_sequence_name}
+            )

     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
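These dictionaries are dynamic-axis annotations for export: each one maps a tensor dimension index to a symbolic name, so `{0: "batch_size", 2: ...}` marks dims 0 and 2 as variable-sized. A sketch of what the loop produces under each layout, wrapped in a hypothetical helper (`chatglm_dynamic_axes` is not part of the patched code; the axis names, including their original spelling, are copied from the diff):

```python
from typing import Dict

def chatglm_dynamic_axes(num_layers: int, is_v4: bool) -> Dict[str, Dict[int, str]]:
    # Symbolic axis names as used in the diff above.
    seq = "past_sequence_length + present_lenght"
    axes = {}
    for i in range(num_layers):
        for kind in ("key", "value"):
            axes[f"present.{i}.{kind}"] = (
                {1: "batch_size", 0: seq} if not is_v4 else {0: "batch_size", 2: seq}
            )
    return axes

# chatglm_dynamic_axes(1, is_v4=True)["present.0.key"]
# -> {0: "batch_size", 2: "past_sequence_length + present_lenght"}
```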