@@ -136,6 +136,7 @@ def __init__(
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
+        self.past_len = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -377,19 +378,21 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-        if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
 
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads
 
         inputs = {}
-        past_len = 0
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    past_len = past_key_values[0][1].shape[-2]
+                    seq_len_dim = -2
+                    if self.config.model_type == "chatglm":
+                        seq_len_dim = 0
+                    elif self.config.model_type == "qwen":
+                        seq_len_dim = 1
+                    self.past_len = past_key_values[0][1].shape[seq_len_dim]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -403,13 +406,14 @@ def prepare_inputs(
                         past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                     )
                 else:
-                    past_len = past_key_values[0].shape[-2]
+                    self.past_len = past_key_values[0].shape[-2]
 
                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))
 
             # Create empty past_key_values for decoder_with_past first generation step
             elif self.use_cache:
+                self.past_len = 0
                 for input_name in self.key_value_input_names:
                     model_inputs = self.model.input(input_name)
                     shape = model_inputs.get_partial_shape()
@@ -432,6 +436,7 @@ def prepare_inputs(
             # Set initial value for the next beam_idx input that will be used at the current iteration
             # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
             self.next_beam_idx = np.arange(batch_size, dtype=int)
+            self.past_len = 0
 
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -440,7 +445,7 @@ def prepare_inputs(
                 attention_mask = np.array(attention_mask)
             else:
                 attention_mask = np.ones(
-                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
+                    (input_ids.shape[0], input_ids.shape[1] + self.past_len), dtype=inputs["input_ids"].dtype
                 )
 
         if "attention_mask" in self.input_names:
@@ -491,6 +496,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
+            self.past_len += input_ids.shape[1]
 
         if not self.stateful:
             if self.use_cache:
@@ -501,24 +507,38 @@ def forward(
                     past_key_values = tuple(
                         past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                     )
+                self.past_len += input_ids.shape[1]
             else:
                 past_key_values = None
+                self.past_len = 0
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
+    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)
 
+        if past_key_values is not None:
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - self.past_len) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif self.past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, self.past_len :]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
+        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
 
         return {
             "input_ids": input_ids,
@@ -594,10 +614,6 @@ def _from_pretrained(
         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
             init_cls = OVBloomForCausalLM
-        elif model_type == "mpt":
-            init_cls = OVMPTForCausalLM
-        elif model_type == "opt":
-            init_cls = OVOPTForCausalLM
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
         else:
@@ -651,22 +667,13 @@ def _from_pretrained(
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
         # only last token for input_ids if past is not None
         if past_key_values and not self.stateful:
             # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
+
+        return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)
 
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache(
@@ -733,36 +740,6 @@ def _convert_to_standard_cache(
         )
 
 
-class OVOPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class OVMPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
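For reference, here is a minimal, self-contained sketch (not part of the library) of the token-trimming rule that the new past_len bookkeeping enables in prepare_inputs_for_generation above. The helper name trim_input_ids is hypothetical; it only mirrors the three cases handled by the added code.

# Illustrative only: a standalone re-statement of the trimming cases above.
# `trim_input_ids` is a hypothetical helper, not part of optimum-intel.
import torch


def trim_input_ids(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, past_len: int) -> torch.LongTensor:
    # 1 - attention_mask longer than input_ids: some inputs live only in the cache
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_len) :]
    # 2 - input_ids still holds every token: drop the ones already cached
    if past_len < input_ids.shape[1]:
        return input_ids[:, past_len:]
    # 3 - otherwise assume input_ids already contains only unprocessed tokens
    return input_ids


# Example: 5 tokens already cached, the full 6-token prompt is resubmitted,
# so only the last (new) token is kept.
ids = torch.arange(6).unsqueeze(0)
mask = torch.ones(1, 6, dtype=torch.long)
assert trim_input_ids(ids, mask, past_len=5).tolist() == [[5]]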