@@ -136,7 +136,7 @@ def __init__(
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
-        self.past_len = 0
+        self._past_length = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -386,12 +386,6 @@ def prepare_inputs(
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    seq_len_dim = -2
-                    if self.config.model_type == "chatglm":
-                        seq_len_dim = 0
-                    elif self.config.model_type == "qwen":
-                        seq_len_dim = 1
-                    self.past_len = past_key_values[0][1].shape[seq_len_dim]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -404,15 +398,13 @@ def prepare_inputs(
                         past_key_values = tuple(
                             past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                         )
-                else:
-                    self.past_len = past_key_values[0].shape[-2]
 
                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))
 
             # Create empty past_key_values for decoder_with_past first generation step
             elif self.use_cache:
-                self.past_len = 0
+                past_len = 0
                 for input_name in self.key_value_input_names:
                     model_inputs = self.model.input(input_name)
                     shape = model_inputs.get_partial_shape()
@@ -435,7 +427,8 @@ def prepare_inputs(
                     # Set initial value for the next beam_idx input that will be used at the current iteration
                     # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                     self.next_beam_idx = np.arange(batch_size, dtype=int)
-                    self.past_len = 0
+                    self._past_length = 0
+        past_len = self._get_past_length(past_key_values)
 
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -444,7 +437,7 @@ def prepare_inputs(
                 attention_mask = np.array(attention_mask)
             else:
                 attention_mask = np.ones(
-                    (input_ids.shape[0], input_ids.shape[1] + self.past_len), dtype=inputs["input_ids"].dtype
+                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
                 )
 
         if "attention_mask" in self.input_names:
@@ -457,7 +450,7 @@ def prepare_inputs(
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
                 if past_key_values:
-                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
 
             inputs["position_ids"] = position_ids
 
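
As a side note on what the `position_ids` block above computes: positions are derived from the cumulative attention mask, and once a cache exists only the positions belonging to the newly passed tokens are kept. A minimal numpy sketch with made-up values (not part of the patch):

```python
import numpy as np

# Toy batch: one left-padded sequence and one full sequence, two new tokens each.
attention_mask = np.array([[0, 0, 1, 1, 1],
                           [1, 1, 1, 1, 1]])
input_ids = np.array([[101, 102],
                      [201, 202]])  # only the not-yet-processed tokens

position_ids = np.cumsum(attention_mask, axis=1) - 1  # running position per token
position_ids[attention_mask == 0] = 1                 # padding gets a dummy position
position_ids = position_ids[:, -input_ids.shape[1]:]  # keep positions of the new tokens

print(position_ids)  # [[1 2]
                     #  [3 4]]
```

This also suggests why the slice replaces the old `expand_dims(..., axis=-1)`: the result should stay two-dimensional (batch, new tokens) even when more than one unprocessed token is passed.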
@@ -495,7 +488,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
-            self.past_len += input_ids.shape[1]
+            self._past_length += input_ids.shape[1]
 
         if not self.stateful:
             if self.use_cache:
@@ -506,10 +499,8 @@ def forward(
                     past_key_values = tuple(
                         past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                     )
-                self.past_len += input_ids.shape[1]
             else:
                 past_key_values = None
-                self.past_len = 0
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
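
For readers unfamiliar with the `num_pkv` slicing above: it regroups the flat tuple of cache tensors returned by the inference request back into per-layer (key, value) pairs. A small illustrative sketch, assuming `num_pkv == 2`:

```python
# Flat cache as produced by the model outputs: (k0, v0, k1, v1, ...)
flat = ("k0", "v0", "k1", "v1")
num_pkv = 2  # key and value tensors per decoder layer

per_layer = tuple(flat[i : i + num_pkv] for i in range(0, len(flat), num_pkv))
print(per_layer)  # (('k0', 'v0'), ('k1', 'v1'))
```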
@@ -520,16 +511,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         use_cache = kwargs.get("use_cache", None)
 
         if past_key_values is not None:
+            past_len = self._get_past_length(past_key_values)
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - self.past_len) :]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif self.past_len < input_ids.shape[1]:
-                input_ids = input_ids[:, self.past_len :]
+            elif past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, past_len:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
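
To make the three trimming cases above concrete, here is a standalone numpy sketch; the helper name `trim_input_ids` and the toy values are illustrative only, not part of the patch:

```python
import numpy as np

def trim_input_ids(input_ids, attention_mask, past_len):
    # Case 1: part of the input lives only in the cache (e.g. inputs_embeds were passed).
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_len):]
    # Case 2: input_ids holds the full sequence, drop what is already cached.
    if past_len < input_ids.shape[1]:
        return input_ids[:, past_len:]
    # Case 3: assume everything in input_ids is unprocessed.
    return input_ids

prompt = np.array([[11, 12, 13, 14, 15]])
mask = np.ones((1, 5), dtype=np.int64)
print(trim_input_ids(prompt, mask, past_len=4))  # [[15]] - only the newest token remains
```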
@@ -547,6 +539,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "attention_mask": attention_mask,
         }
 
+    def _get_past_length(self, past_key_values=None):
+        if past_key_values is None:
+            return 0
+        if self.stateful:
+            return self._past_length
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[-2]
+        seq_length_dim = -2
+        if self.config.model_type == "chatglm":
+            seq_length_dim = 0
+        elif self.config.model_type == "qwen":
+            seq_length_dim = 1
+        # input is tuple of pairs
+        if isinstance(past_key_values[0], (tuple, list)):
+            return past_key_values[0][1].shape[seq_length_dim]
+        # past key values comes after flattening
+        return past_key_values[1].shape[seq_length_dim]
+
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
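
A rough sketch of why the new `_get_past_length` helper checks `isinstance(past_key_values[0], (tuple, list))`: in the stateless path the cache can arrive either as per-layer (key, value) pairs or already flattened, and in both layouts the default sequence-length axis is `-2`. The shapes below are made up for illustration:

```python
import numpy as np

# 2 layers, batch 1, 4 heads, 5 cached tokens, head size 8 (toy shapes).
paired = tuple((np.zeros((1, 4, 5, 8)), np.zeros((1, 4, 5, 8))) for _ in range(2))
flattened = tuple(t for pair in paired for t in pair)

def past_length(past_key_values, seq_length_dim=-2):
    # Same idea as _get_past_length for the default (non-chatglm/qwen, non-MQA) layout.
    if isinstance(past_key_values[0], (tuple, list)):
        return past_key_values[0][1].shape[seq_length_dim]  # per-layer (key, value) pairs
    return past_key_values[1].shape[seq_length_dim]         # already flattened cache

print(past_length(paired), past_length(flattened))  # 5 5
```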