import copy
import logging
import os
- import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import openvino
@@ -28,21 +27,10 @@
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
- from transformers.generation.beam_search import BeamScorer, ConstrainedBeamSearchScorer
from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
from transformers.generation.logits_process import LogitsProcessorList
- from transformers.generation.stopping_criteria import (
-     EosTokenCriteria,
-     StoppingCriteriaList,
-     validate_stopping_criteria,
- )
- from transformers.generation.utils import (
-     GenerateBeamDecoderOnlyOutput,
-     GenerateBeamOutput,
-     GenerateOutput,
-     _split_model_inputs,
-     stack_model_outputs,
- )
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
+ from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from optimum.utils.normalized_config import NormalizedConfigManager
@@ -398,7 +386,11 @@ def prepare_inputs(
        inputs = {}
        if not self.stateful:
            if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                    if self._pkv_precision == Type.bf16:
                        # numpy does not support bf16, pretending f16, should change to bf16
                        past_key_values = tuple(
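Note on the widened condition: in Python, `and` binds tighter than `or`, so the new check reads as "not a multi-query-attention model, or (falcon with the new decoder architecture)". A minimal standalone sketch of that grouping, with placeholder values (not part of the patch):

```python
# `A or B and C` parses as `A or (B and C)`.
model_type = "falcon"                               # placeholder
multi_query_attn_models = {"falcon", "gpt_bigcode"}  # placeholder set for the sketch
new_decoder_architecture = True                      # falcon-40b-style config flag

use_per_layer_kv_pairs = (
    model_type not in multi_query_attn_models
    or (model_type == "falcon" and new_decoder_architecture)
)
print(use_per_layer_kv_pairs)  # True -> this branch treats the cache as per-layer (key, value) pairs
```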
@@ -491,9 +483,6 @@ def forward(
            position_ids=position_ids,
            **kwargs,
        )
-
-        print(inputs["input_ids"].shape)
-
        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
@@ -509,7 +498,11 @@ def forward(
        if self.use_cache:
            # Tuple of length equal to: number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
            past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+            if (
+                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                or self.config.model_type == "falcon"
+                and self.config.new_decoder_architecture
+            ):
                # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                past_key_values = tuple(
                    past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
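The block above regroups the flat tuple of cache tensors returned by the OpenVINO request into per-layer `(key, value)` pairs. A standalone sketch of the slicing pattern with toy values (names are illustrative):

```python
# Group a flat sequence into chunks of `num_pkv` items, one chunk per decoder layer.
flat_outputs = ("k0", "v0", "k1", "v1", "k2", "v2")  # stand-ins for the per-layer cache tensors
num_pkv = 2  # key and value per self-attention layer

per_layer = tuple(flat_outputs[i : i + num_pkv] for i in range(0, len(flat_outputs), num_pkv))
print(per_layer)  # (('k0', 'v0'), ('k1', 'v1'), ('k2', 'v2'))
```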
@@ -561,21 +554,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
        return model_inputs

    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
        if indicies.shape[0] != 1:
            logits = logits[indicies]
            if past_key_values and not self.stateful:
-                past_key_values = tuple(
-                    tuple(
-                        past_state[indicies]
-                        if not self.config.model_type == "chatglm"
-                        else past_state[:, indicies, ...]
-                        for past_state in layer_past
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indicies]
+                            if not self.config.model_type == "chatglm"
+                            else past_state[:, indicies, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
                    )
-                    for layer_past in past_key_values
-                )
-            if self.stateful:
-                self.next_beam_idx = self.next_beam_idx[indicies]
-                self._second_iter_beam_search = True
+                else:
+                    past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
        return logits, past_key_values

    def _deduplicate_inputs(self, model_inputs: Dict):
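`_deduplicate_inputs` collapses identical prompts before inference, and `_expand_outputs_for_generation` above maps the results back to every beam. The mechanism is `np.unique` with `return_index`/`return_inverse`; a minimal sketch with made-up data (not library code):

```python
import numpy as np

# Beam search duplicates the same prompt once per beam; rows 0 and 1 are identical here.
input_ids = np.array([[1, 2, 3], [1, 2, 3], [7, 8, 9]])

unique_ids, indicies, reverse = np.unique(input_ids, axis=0, return_index=True, return_inverse=True)
fake_logits = np.array([[0.1, 0.9], [0.8, 0.2]])  # pretend output: one row per unique prompt
expanded = fake_logits[reverse.reshape(-1)]       # broadcast back: one row per original prompt/beam
assert expanded.shape[0] == input_ids.shape[0]
```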
@@ -591,12 +596,19 @@ def _deduplicate_inputs(self, model_inputs: Dict):
            else:
                shape = input_tensor.shape
                dtype = input_tensor.element_type
-                shape[0 if not self.config.model_type == "chatglm" else 1] = indicies.shape[0]
+                upd_batch_size = indicies.shape[0]
+                if self.config.model_type == "bloom":
+                    upd_batch_size *= self.config.num_attention_heads
+                shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
                upd_model_inputs[input_name] = Tensor(dtype, shape)
-                print(f"{input_name}: {upd_model_inputs[input_name].shape}")
        upd_model_inputs["input_ids"] = unique_input_ids
        if "beam_idx" in model_inputs:
-            beam_idx = np.arange(unique_input_ids.shape[0], dtype=int)
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
            upd_model_inputs["beam_idx"] = beam_idx
        return upd_model_inputs, reverse_indicies
@@ -646,7 +658,9 @@ def _get_past_length(self, past_key_values=None):
            return 0
        if self.stateful:
            return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
            return past_key_values[0].shape[-2]
        seq_length_dim = -2
        if self.config.model_type == "chatglm":
@@ -677,9 +691,14 @@ def _reorder_cache(
            self._second_iter_beam_search = False
            return past_key_values
        else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -800,11 +819,12 @@ def _reorder_cache(
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """
        if self.stateful:
-            beam_idx = np.array(beam_idx)
            batch_size = beam_idx.shape[0]
+            beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
            indices = np.array(range(batch_size * self.config.num_attention_heads))
            indices = indices.reshape([batch_size, self.config.num_attention_heads])
            self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            self._second_iter_beam_search = False
            return past_key_values
        else:
            standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
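Bloom keeps batch and attention heads fused in the leading cache dimension, so the per-sequence beam indices have to be expanded to per-(sequence, head) row indices, as in the stateful branch above. A standalone sketch with toy sizes (not library code):

```python
import numpy as np

batch_size, num_attention_heads = 2, 4  # toy sizes
beam_idx = np.array([1, 1])             # both beams continue from sequence 1

indices = np.arange(batch_size * num_attention_heads).reshape([batch_size, num_attention_heads])
next_beam_idx = np.take(indices, beam_idx, 0).flatten()
print(next_beam_idx)  # [4 5 6 7 4 5 6 7] -> every head row of sequence 1, once per beam
```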
@@ -854,14 +874,34 @@ def _convert_to_standard_cache(
            for layer_past in past_key_value
        )

+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
+                pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
+                past_key_values = self._convert_to_bloom_cache(pkv)
+
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
+        return logits, past_key_values
+

class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        if self.stateful:
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            self._second_iter_beam_search = False
            return past_key_values
        else:
            return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
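End-to-end, these changes are exercised whenever beam search runs on an OpenVINO decoder. A hedged usage sketch (checkpoint name and generation settings are illustrative only):

```python
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "EleutherAI/gpt-neo-125m"  # placeholder causal LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # convert to OpenVINO IR on the fly

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# num_beams > 1 drives the _reorder_cache / cache-reshaping paths touched in this diff.
outputs = model.generate(**inputs, num_beams=4, num_return_sequences=2, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```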