@@ -120,6 +120,7 @@ def __init__(
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
+        self._past_length = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -356,19 +357,14 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-        if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
-
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads

         inputs = {}
-        past_len = 0
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    past_len = past_key_values[0][1].shape[-2]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -381,8 +377,6 @@ def prepare_inputs(
                     past_key_values = tuple(
                         past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                     )
-                else:
-                    past_len = past_key_values[0].shape[-2]

                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))
@@ -411,6 +405,8 @@ def prepare_inputs(
                 # Set initial value for the next beam_idx input that will be used at the current iteration
                 # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                 self.next_beam_idx = np.arange(batch_size, dtype=int)
+                self._past_length = 0
+        past_len = self._get_past_length(past_key_values)

         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -432,7 +428,7 @@ def prepare_inputs(
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
                 if past_key_values:
-                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+                    position_ids = position_ids[:, -input_ids.shape[1] :]

             inputs["position_ids"] = position_ids

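The hunk above stops collapsing `position_ids` to the single last position and instead keeps one position per token still present in `input_ids`, which also covers steps that feed more than one unprocessed token. A minimal standalone numpy sketch of the difference, with invented shapes:

```python
# Standalone sketch of the position_ids change above; shapes are illustrative only.
import numpy as np

attention_mask = np.ones((1, 6), dtype=np.int64)       # 4 cached tokens + 2 new tokens
position_ids = np.cumsum(attention_mask, axis=1) - 1   # [[0, 1, 2, 3, 4, 5]]

# Old behaviour: always keep just the last position (only valid for 1-token steps).
old = np.expand_dims(position_ids[:, -1], axis=-1)     # [[5]]

# New behaviour: keep one position per token remaining in input_ids.
input_ids = np.array([[101, 102]])                     # e.g. 2 unprocessed tokens
new = position_ids[:, -input_ids.shape[1]:]            # [[4, 5]]
print(old, new)
```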
@@ -470,6 +466,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
+            self._past_length += input_ids.shape[1]

         if not self.stateful:
             if self.use_cache:
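For stateful models the KV-cache is held inside the OpenVINO inference request rather than returned to Python, so the wrapper tracks the processed length itself: `_past_length` is reset in `prepare_inputs` for a new sequence and incremented here in `forward`. A toy sketch of that bookkeeping, not the optimum-intel implementation:

```python
# Toy model of the stateful bookkeeping added above; names are hypothetical.
class PastLengthTracker:
    def __init__(self):
        self._past_length = 0

    def start_sequence(self):              # mirrors the reset in prepare_inputs
        self._past_length = 0

    def forward(self, num_input_tokens):   # mirrors the increment in forward()
        self._past_length += num_input_tokens
        return self._past_length

tracker = PastLengthTracker()
tracker.start_sequence()
for n in (6, 1, 1):                        # prefill of 6 tokens, then two decode steps
    print(tracker.forward(n))              # 6, 7, 8
```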
@@ -485,19 +482,32 @@ def forward(

         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
+    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)

+        if past_key_values is not None:
+            past_len = self._get_past_length(past_key_values)
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, past_len:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
+        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         return {
             "input_ids": input_ids,
@@ -507,6 +517,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "attention_mask": attention_mask,
         }

+    def _get_past_length(self, past_key_values=None):
+        if past_key_values is None:
+            return 0
+        if self.stateful:
+            return self._past_length
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[-2]
+        seq_length_dim = -2
+        if self.config.model_type == "chatglm":
+            seq_length_dim = 0
+        elif self.config.model_type == "qwen":
+            seq_length_dim = 1
+        # input is tuple of pairs
+        if isinstance(past_key_values[0], (tuple, list)):
+            return past_key_values[0][1].shape[seq_length_dim]
+        # past key values comes after flattening
+        return past_key_values[1].shape[seq_length_dim]
+
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
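The new `_get_past_length` helper reads the cached sequence length either from the tracked counter (stateful models) or from the key/value tensor shapes (stateless models), with the sequence dimension varying by model type. A rough illustration of the default `seq_length_dim = -2`, tuple-of-pairs case, using dummy numpy arrays with invented shapes:

```python
# Rough illustration of how the cached sequence length is read for a stateless,
# non-multi-query model; all shapes below are invented for the demo.
import numpy as np

num_layers, batch, heads, past_len, head_dim = 2, 1, 4, 5, 8
# "tuple of pairs" layout: one (key, value) pair per layer, sequence length on dim -2
past_key_values = tuple(
    (
        np.zeros((batch, heads, past_len, head_dim)),
        np.zeros((batch, heads, past_len, head_dim)),
    )
    for _ in range(num_layers)
)
print(past_key_values[0][1].shape[-2])  # 5, matches the default seq_length_dim = -2
```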
@@ -573,10 +601,6 @@ def _from_pretrained(
         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
             init_cls = OVBloomForCausalLM
-        elif model_type == "mpt":
-            init_cls = OVMPTForCausalLM
-        elif model_type == "opt":
-            init_cls = OVOPTForCausalLM
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
         else:
@@ -630,22 +654,12 @@ def _from_pretrained(
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
         # only last token for input_ids if past is not None
         if past_key_values and not self.stateful:
             # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
+        return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)

     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache(
@@ -712,36 +726,6 @@ def _convert_to_standard_cache(
         )


-class OVOPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class OVMPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(