import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import openvino
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
+from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from optimum.utils.normalized_config import NormalizedConfigManager
from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE


+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.streamers import BaseStreamer
+
+
logger = logging.getLogger(__name__)

core = Core()
@@ -122,6 +131,8 @@ def __init__(
        self._pkv_precision = Type.f32
        self.next_beam_idx = None
        self._past_length = 0
+        self._first_iter_beam_search = False
+        self._second_iter_beam_search = False
        self.update_pkv_precision()
        if self.is_dynamic:
            self.model = self._reshape(self.model, -1, -1)
@@ -375,7 +386,11 @@ def prepare_inputs(
        inputs = {}
        if not self.stateful:
            if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                    if self._pkv_precision == Type.bf16:
                        # numpy does not support bf16, pretending f16, should change to bf16
                        past_key_values = tuple(
@@ -418,7 +433,6 @@ def prepare_inputs(
                    self.next_beam_idx = np.arange(batch_size, dtype=int)
                    self._past_length = 0
        past_len = self._get_past_length(past_key_values)
-
        inputs["input_ids"] = np.array(input_ids)
        # Add the attention_mask inputs when needed
        if "attention_mask" in self.input_names or "position_ids" in self.input_names:
@@ -468,6 +482,8 @@ def forward(
            **kwargs,
        )

+        if self._first_iter_beam_search:
+            inputs, duplication_indices = self._deduplicate_inputs(inputs)
        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
@@ -483,14 +499,22 @@ def forward(
        if self.use_cache:
            # Tuple of length equal to: number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
            past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+            if (
+                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                or self.config.model_type == "falcon"
+                and self.config.new_decoder_architecture
+            ):
                # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                past_key_values = tuple(
                    past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
                )
        else:
            past_key_values = None

+        if self._first_iter_beam_search:
+            logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
+            self._first_iter_beam_search = False
+
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
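The deduplication that forward() now performs on the first beam-search step relies on np.unique over the prompt rows: beam search replicates every prompt num_beams times, so inference only has to run once per unique prompt, and the logits are re-expanded afterwards with the reverse indices returned by np.unique. A minimal standalone numpy sketch of that round trip (not part of the diff; shapes and names are illustrative):

import numpy as np

# beam search feeds each prompt `num_beams` times (here 2 prompts, 3 beams each)
input_ids = np.array([[1, 2, 3]] * 3 + [[4, 5, 6]] * 3)

# collapse duplicate rows; `reverse` maps every original row back to its unique row
unique_ids, first_occurrence, reverse = np.unique(input_ids, axis=0, return_index=True, return_inverse=True)

# stand-in for the real inference call, executed on the 2 unique rows only
unique_logits = np.random.rand(unique_ids.shape[0], unique_ids.shape[1], 8)

# re-expand to the original batch of 6 rows, matching what beam search expects
expanded_logits = unique_logits[np.ravel(reverse)]
assert expanded_logits.shape[0] == input_ids.shape[0]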
@@ -520,20 +544,124 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

-        return {
+        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

+        return model_inputs
+
+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indicies]
+                            if not self.config.model_type == "chatglm"
+                            else past_state[:, indicies, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
+                    )
+                else:
+                    past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+            self._second_iter_beam_search = True
+        return logits, past_key_values
+
+    def _deduplicate_inputs(self, model_inputs: Dict):
+        input_ids = model_inputs["input_ids"]
+        upd_model_inputs = {}
+        unique_input_ids, indicies, reverse_indicies = np.unique(
+            input_ids, axis=0, return_index=True, return_inverse=True
+        )
+        for input_name, input_tensor in model_inputs.items():
+            if input_name not in ["input_ids", "beam_idx"]:
+                if not isinstance(input_tensor, Tensor):
+                    upd_model_inputs[input_name] = input_tensor[indicies]
+                else:
+                    shape = input_tensor.shape
+                    dtype = input_tensor.element_type
+                    upd_batch_size = indicies.shape[0]
+                    if self.config.model_type == "bloom":
+                        upd_batch_size *= self.config.num_attention_heads
+                    shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
+                    upd_model_inputs[input_name] = Tensor(dtype, shape)
+        upd_model_inputs["input_ids"] = unique_input_ids
+        if "beam_idx" in model_inputs:
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
+            upd_model_inputs["beam_idx"] = beam_idx
+        return upd_model_inputs, reverse_indicies
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
+        generation_mode = _generation_config.get_generation_mode(assistant_model)
+
+        is_beam_search = generation_mode in [
+            GenerationMode.BEAM_SEARCH,
+            GenerationMode.BEAM_SAMPLE,
+            GenerationMode.GROUP_BEAM_SEARCH,
+            GenerationMode.CONSTRAINED_BEAM_SEARCH,
+        ]
+        if is_beam_search:
+            self._first_iter_beam_search = True
+        result = super().generate(
+            inputs,
+            generation_config,
+            logits_processor,
+            stopping_criteria,
+            prefix_allowed_tokens_fn,
+            synced_gpus,
+            assistant_model,
+            streamer,
+            negative_prompt_ids,
+            negative_prompt_attention_mask,
+            **kwargs,
+        )
+        return result
+
    def _get_past_length(self, past_key_values=None):
        if past_key_values is None:
            return 0
        if self.stateful:
            return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
            return past_key_values[0].shape[-2]
        seq_length_dim = -2
        if self.config.model_type == "chatglm":
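The generate() override above only sets _first_iter_beam_search when the resolved generation mode is one of the beam variants, so requesting num_beams > 1 is enough to exercise the new deduplication path. A hedged usage sketch (the checkpoint name and prompt are placeholders, assuming the optimum-intel OVModelForCausalLM export flow):

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"  # placeholder; any causal LM that can be exported to OpenVINO
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("OpenVINO beam search test", return_tensors="pt")
# num_beams > 1 resolves to GenerationMode.BEAM_SEARCH, so the first forward call
# deduplicates the replicated prompt rows before running inference
outputs = model.generate(**inputs, num_beams=4, num_return_sequences=2, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))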
@@ -558,12 +686,20 @@ def _reorder_cache(
        if self.stateful:
            # TODO: Apply it differently based on model type
            # TODO: At least for bloom we need to replicate values for each attention head
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = (
+                np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            )  # save beam_idx to be used as an input in the next iteration
+            self._second_iter_beam_search = False
            return past_key_values
        else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -684,11 +820,12 @@ def _reorder_cache(
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """
        if self.stateful:
-            beam_idx = np.array(beam_idx)
            batch_size = beam_idx.shape[0]
+            beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
            indices = np.array(range(batch_size * self.config.num_attention_heads))
            indices = indices.reshape([batch_size, self.config.num_attention_heads])
            self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            self._second_iter_beam_search = False
            return past_key_values
        else:
            standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
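The stateful bloom _reorder_cache in the hunk above replicates each selected beam index across attention heads before it becomes the next beam_idx input, since bloom's cache fuses the batch and head dimensions. A standalone numpy illustration of that reshape/take/flatten step (the sizes are made up):

import numpy as np

batch_size, num_heads = 2, 4     # illustrative sizes
beam_idx = np.array([1, 1])      # beams chosen by the search for this step

indices = np.arange(batch_size * num_heads).reshape([batch_size, num_heads])
next_beam_idx = np.take(indices, beam_idx, 0).flatten()
# next_beam_idx -> [4 5 6 7 4 5 6 7], i.e. the head rows of beam 1 repeated for both sequences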
@@ -738,14 +875,34 @@ def _convert_to_standard_cache(
                for layer_past in past_key_value
            )

+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
+                pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
+                past_key_values = self._convert_to_bloom_cache(pkv)
+
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+            self._second_iter_beam_search = True
+        return logits, past_key_values
+


class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        if self.stateful:
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            self._second_iter_beam_search = False
            return past_key_values
        else:
            return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)