@@ -120,7 +120,7 @@ def __init__(
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
-        self.past_len = 0
+        self._past_length = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -365,12 +365,6 @@ def prepare_inputs(
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    seq_len_dim = -2
-                    if self.config.model_type == "chatglm":
-                        seq_len_dim = 0
-                    elif self.config.model_type == "qwen":
-                        seq_len_dim = 1
-                    self.past_len = past_key_values[0][1].shape[seq_len_dim]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -383,15 +377,13 @@ def prepare_inputs(
                         past_key_values = tuple(
                             past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                         )
-                else:
-                    self.past_len = past_key_values[0].shape[-2]
 
                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))
 
             # Create empty past_key_values for decoder_with_past first generation step
             elif self.use_cache:
-                self.past_len = 0
+                past_len = 0
                 for input_name in self.key_value_input_names:
                     model_inputs = self.model.input(input_name)
                     shape = model_inputs.get_partial_shape()
@@ -414,7 +406,8 @@ def prepare_inputs(
                     # Set initial value for the next beam_idx input that will be used at the current iteration
                     # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                     self.next_beam_idx = np.arange(batch_size, dtype=int)
-                    self.past_len = 0
+                    self._past_length = 0
+        past_len = self._get_past_length(past_key_values)
 
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -423,7 +416,7 @@ def prepare_inputs(
                 attention_mask = np.array(attention_mask)
             else:
                 attention_mask = np.ones(
-                    (input_ids.shape[0], input_ids.shape[1] + self.past_len), dtype=inputs["input_ids"].dtype
+                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
                 )
 
         if "attention_mask" in self.input_names:
@@ -436,7 +429,7 @@ def prepare_inputs(
             position_ids = np.cumsum(attention_mask, axis=1) - 1
             position_ids[attention_mask == 0] = 1
             if past_key_values:
-                position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+                position_ids = np.expand_dims(position_ids[:, -input_ids.shape[1] :], axis=-1)
 
             inputs["position_ids"] = position_ids
 
@@ -474,7 +467,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
-            self.past_len += input_ids.shape[1]
+            self._past_length += input_ids.shape[1]
 
         if not self.stateful:
             if self.use_cache:
@@ -485,10 +478,8 @@ def forward(
                     past_key_values = tuple(
                         past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                     )
-                self.past_len += input_ids.shape[1]
             else:
                 past_key_values = None
-                self.past_len = 0
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
@@ -499,16 +490,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         use_cache = kwargs.get("use_cache", None)
 
         if past_key_values is not None:
+            past_len = self._get_past_length(past_key_values)
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - self.past_len) :]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif self.past_len < input_ids.shape[1]:
-                input_ids = input_ids[:, self.past_len :]
+            elif past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, past_len:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
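For reference, a minimal numpy sketch of the three trimming cases handled in the hunk above (toy shapes and values, not part of the PR itself):

import numpy as np

# Hypothetical decode step: 3 tokens already sit in the KV cache, 8 tokens are passed in.
past_len = 3
input_ids = np.arange(8).reshape(1, 8)
attention_mask = np.ones((1, 8), dtype=np.int64)

# Case 1: attention_mask longer than input_ids -> keep only the uncached tail.
if attention_mask.shape[1] > input_ids.shape[1]:
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_len):]
# Case 2: past_len shorter than input_ids -> drop the first past_len tokens.
elif past_len < input_ids.shape[1]:
    input_ids = input_ids[:, past_len:]
# Case 3 (past_len >= input_ids.shape[1]): input_ids is assumed to hold only new tokens.

print(input_ids)  # [[3 4 5 6 7]] -- the five tokens not yet in the cache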
@@ -526,6 +518,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "attention_mask": attention_mask,
         }
 
+    def _get_past_length(self, past_key_values=None):
+        if past_key_values is None:
+            return 0
+        if self.stateful:
+            return self._past_length
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[-2]
+        seq_length_dim = -2
+        if self.config.model_type == "chatglm":
+            seq_length_dim = 0
+        elif self.config.model_type == "qwen":
+            seq_length_dim = 1
+        # input is tuple of pairs
+        if isinstance(past_key_values[0], (tuple, list)):
+            return past_key_values[0][1].shape[seq_length_dim]
+        # past key values comes after flattening
+        return past_key_values[1].shape[seq_length_dim]
+
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
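As a rough illustration of the two non-stateful cache layouts the new _get_past_length helper distinguishes (toy numpy tensors with assumed shapes for a 2-layer model; chatglm and qwen, which keep the sequence axis on dimension 0 and 1 respectively, are omitted):

import numpy as np

batch, heads, cached_len, head_dim = 1, 4, 7, 64

def kv():
    # Stand-in key or value tensor with the usual (batch, heads, seq, head_dim) layout.
    return np.zeros((batch, heads, cached_len, head_dim))

# Layout 1: tuple of (key, value) pairs per layer -> sequence length on dim -2 of the value tensor.
pkv_pairs = tuple((kv(), kv()) for _ in range(2))
assert pkv_pairs[0][1].shape[-2] == cached_len

# Layout 2: flattened tuple (what the model inputs consume) -> first layer's value is element [1].
pkv_flat = tuple(t for pair in pkv_pairs for t in pair)
assert pkv_flat[1].shape[-2] == cached_len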