 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import warnings
 import logging
 import os
-import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import openvino
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
-from transformers.generation.beam_search import BeamScorer, ConstrainedBeamSearchScorer
 from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
 from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import (
-    EosTokenCriteria,
-    StoppingCriteriaList,
-    validate_stopping_criteria,
-)
-from transformers.generation.utils import (
-    GenerateBeamDecoderOnlyOutput,
-    GenerateBeamOutput,
-    GenerateOutput,
-    _split_model_inputs,
-    stack_model_outputs,
-)
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.utils import GenerateOutput
 from transformers.modeling_outputs import CausalLMOutputWithPast

 from optimum.utils.normalized_config import NormalizedConfigManager
@@ -398,7 +387,11 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
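Note: the reshaping branch now also covers Falcon checkpoints with the new decoder architecture, whose cache is laid out per attention head like non-multi-query models. The condition mixes `or` and `and` without parentheses, so it relies on `and` binding tighter than `or`. A minimal sketch of how the test groups (illustration only; `needs_per_head_cache` and the reduced `MULTI_QUERY_ATTN_MODELS` set are hypothetical stand-ins, not part of the patch):

```python
# Python binds `and` tighter than `or`, so the test reads as
#   (model_type not in MULTI_QUERY_ATTN_MODELS) or (model_type == "falcon" and new_decoder_architecture)
MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}  # illustrative subset, not the real constant

def needs_per_head_cache(model_type: str, new_decoder_architecture: bool = False) -> bool:
    return (
        model_type not in MULTI_QUERY_ATTN_MODELS
        or model_type == "falcon"
        and new_decoder_architecture
    )

assert needs_per_head_cache("llama")                                  # not multi-query
assert needs_per_head_cache("falcon", new_decoder_architecture=True)  # falcon-40b-style per-head cache
assert not needs_per_head_cache("falcon", new_decoder_architecture=False)
assert not needs_per_head_cache("gpt_bigcode")
```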
@@ -491,9 +484,6 @@ def forward(
             position_ids=position_ids,
             **kwargs,
         )
-
-        print(inputs["input_ids"].shape)
-
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
@@ -509,7 +499,11 @@ def forward(
         if self.use_cache:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+            if (
+                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                or self.config.model_type == "falcon"
+                and self.config.new_decoder_architecture
+            ):
                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                 past_key_values = tuple(
                     past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
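Note: after inference the OpenVINO request returns one flat tuple of cache tensors, and for models with a per-head cache they are regrouped into per-layer tuples of `self.num_pkv` tensors (typically key and value). A toy sketch of that regrouping (illustration only, with string placeholders instead of tensors):

```python
# Regrouping a flat tuple of cache outputs into per-layer (key, value) pairs,
# as done with self.num_pkv in the patch.
flat_outputs = ("k0", "v0", "k1", "v1", "k2", "v2")  # toy stand-ins for layer caches
num_pkv = 2  # key and value per decoder layer

per_layer = tuple(flat_outputs[i : i + num_pkv] for i in range(0, len(flat_outputs), num_pkv))
assert per_layer == (("k0", "v0"), ("k1", "v1"), ("k2", "v2"))
```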
@@ -561,21 +555,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         return model_inputs

     def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
         if indicies.shape[0] != 1:
             logits = logits[indicies]
             if past_key_values and not self.stateful:
-                past_key_values = tuple(
-                    tuple(
-                        past_state[indicies]
-                        if not self.config.model_type == "chatglm"
-                        else past_state[:, indicies, ...]
-                        for past_state in layer_past
+                if (
+                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
+                    or self.config.model_type == "falcon"
+                    and self.config.new_decoder_architecture
+                ):
+                    past_key_values = tuple(
+                        tuple(
+                            past_state[indicies]
+                            if not self.config.model_type == "chatglm"
+                            else past_state[:, indicies, ...]
+                            for past_state in layer_past
+                        )
+                        for layer_past in past_key_values
                     )
-                    for layer_past in past_key_values
-                )
-            if self.stateful:
-                self.next_beam_idx = self.next_beam_idx[indicies]
-                self._second_iter_beam_search = True
+                else:
+                    past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
         return logits, past_key_values

     def _deduplicate_inputs(self, model_inputs: Dict):
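Note: `_expand_outputs_for_generation` is the counterpart of `_deduplicate_inputs`. The logits (and, for stateless models, the cache) computed on the de-duplicated batch are expanded back to one row per beam by indexing with the inverse indices, while stateful models only record the mapping in `next_beam_idx`, falling back to `np.arange(batch_size)[indicies]` when no reordering has happened yet. A minimal NumPy sketch of the expansion (illustration only):

```python
import numpy as np

unique_logits = np.array([[0.1, 0.9], [0.7, 0.3]])  # logits for the 2 unique prompts
indicies = np.array([0, 0, 1, 1])                   # inverse map: 4 beams -> 2 unique rows

expanded = unique_logits[indicies]                  # one logits row per beam again
assert expanded.shape == (4, 2)

# stateful case: instead of copying the cache, remember which state each beam should read
next_beam_idx = np.arange(unique_logits.shape[0], dtype=int)[indicies]
assert next_beam_idx.tolist() == [0, 0, 1, 1]
```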
@@ -591,12 +597,19 @@ def _deduplicate_inputs(self, model_inputs: Dict):
                 else:
                     shape = input_tensor.shape
                     dtype = input_tensor.element_type
-                    shape[0 if not self.config.model_type == "chatglm" else 1] = indicies.shape[0]
+                    upd_batch_size = indicies.shape[0]
+                    if self.config.model_type == "bloom":
+                        upd_batch_size *= self.config.num_attention_heads
+                    shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
                     upd_model_inputs[input_name] = Tensor(dtype, shape)
-                    print(f"{input_name}: {upd_model_inputs[input_name].shape}")
         upd_model_inputs["input_ids"] = unique_input_ids
         if "beam_idx" in model_inputs:
-            beam_idx = np.arange(unique_input_ids.shape[0], dtype=int)
+            beam_range = (
+                unique_input_ids.shape[0]
+                if self.config.model_type != "bloom"
+                else unique_input_ids.shape[0] * self.config.num_attention_heads
+            )
+            beam_idx = np.arange(beam_range, dtype=int)
             upd_model_inputs["beam_idx"] = beam_idx
         return upd_model_inputs, reverse_indicies
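Note: `_deduplicate_inputs` collapses the identical prompt copies that beam search creates so that only unique rows are run through the model; for `bloom`, stateful inputs are laid out as `batch_size * num_attention_heads`, hence the extra scaling of the tensor shapes and of the `beam_idx` range. A minimal sketch of the de-duplication mechanics (illustration only; the head count is an arbitrary example value):

```python
import numpy as np

input_ids = np.array([[1, 2, 3],
                      [1, 2, 3],
                      [4, 5, 6],
                      [4, 5, 6]])  # beam search starts with repeated prompts

unique_input_ids, indicies, reverse_indicies = np.unique(
    input_ids, axis=0, return_index=True, return_inverse=True
)
assert unique_input_ids.shape[0] == 2                    # only 2 rows go through the model
assert reverse_indicies.ravel().tolist() == [0, 0, 1, 1]  # map back to the full beam batch later

# bloom-style layout: stateful inputs are shaped [batch * num_heads, ...],
# so per-row quantities are scaled by the head count
num_attention_heads = 4  # illustrative value
beam_idx = np.arange(unique_input_ids.shape[0] * num_attention_heads, dtype=int)
assert beam_idx.shape[0] == 8
```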
@@ -646,7 +659,9 @@ def _get_past_length(self, past_key_values=None):
             return 0
         if self.stateful:
             return self._past_length
-        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
+            self.config.model_type == "falcon" and self.config.new_decoder_architecture
+        ):
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
         if self.config.model_type == "chatglm":
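Note: for multi-query models the per-layer cache is a single tensor whose second-to-last axis is the past sequence length, while standard caches (and Falcon with the new decoder architecture) keep per-head key/value pairs. A tiny sketch of where the length is read from (illustration only, with arbitrary shapes):

```python
import numpy as np

mqa_layer_cache = np.zeros((2, 5, 64))                 # [batch, past_len, head_dim]
standard_layer_cache = (np.zeros((2, 8, 5, 64)),) * 2  # (key, value): [batch, heads, past_len, head_dim]

assert mqa_layer_cache.shape[-2] == 5        # flat multi-query cache
assert standard_layer_cache[0].shape[-2] == 5  # per-head cache, same axis convention
```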
@@ -677,9 +692,14 @@ def _reorder_cache(
             self._second_iter_beam_search = False
             return past_key_values
         else:
-            return tuple(
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
-            )
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
+            ):
+                return tuple(
+                    tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
+                    for layer_past in past_key_values
+                )
+            return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)

     def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
@@ -800,11 +820,12 @@ def _reorder_cache(
         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
         if self.stateful:
-            beam_idx = np.array(beam_idx)
             batch_size = beam_idx.shape[0]
+            beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
             self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
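Note: Bloom's stateful cache is laid out as `[batch_size * num_attention_heads, ...]`, so a per-sequence `beam_idx` has to be expanded to one index per head before it can be fed back as the `beam_idx` input; the `_second_iter_beam_search` flag keeps the index produced by `_expand_outputs_for_generation` from being overwritten on the first reorder after de-duplication. A minimal sketch of the per-head expansion (illustration only):

```python
import numpy as np

batch_size = 2
num_attention_heads = 3
beam_idx = np.array([1, 0])  # beam 0 now continues sequence 1, beam 1 continues sequence 0

indices = np.array(range(batch_size * num_attention_heads)).reshape([batch_size, num_attention_heads])
next_beam_idx = np.take(indices, beam_idx, 0).flatten()
assert next_beam_idx.tolist() == [3, 4, 5, 0, 1, 2]  # one cache row index per attention head
```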
@@ -854,14 +875,34 @@ def _convert_to_standard_cache(
             for layer_past in past_key_value
         )

+    def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
+        batch_size = logits.shape[0]
+        if indicies.shape[0] != 1:
+            logits = logits[indicies]
+            if past_key_values and not self.stateful:
+                pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
+                pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
+                past_key_values = self._convert_to_bloom_cache(pkv)
+
+            if self.stateful:
+                self.next_beam_idx = (
+                    self.next_beam_idx[indicies]
+                    if self.next_beam_idx is not None
+                    else np.arange(batch_size, dtype=int)[indicies]
+                )
+                self._second_iter_beam_search = True
+        return logits, past_key_values
+

 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
         if self.stateful:
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
+            # save beam_idx to be used as an input in the next iteration
+            self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
+            self._second_iter_beam_search = False
             return past_key_values
         else:
             return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)
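Note: the new Bloom `_expand_outputs_for_generation` converts the cache to the standard `[batch, num_heads, ...]` layout, indexes it per beam, and converts it back. A minimal sketch of that round-trip (illustration only; the shapes follow Bloom's usual key/value layout but are assumptions here):

```python
import numpy as np

batch_size, num_heads, head_dim, seq_len = 2, 3, 4, 5
key = np.zeros((batch_size * num_heads, head_dim, seq_len))    # bloom layout
value = np.zeros((batch_size * num_heads, seq_len, head_dim))

# to the standard per-head layout
key_std = key.reshape(batch_size, num_heads, head_dim, seq_len)
value_std = value.reshape(batch_size, num_heads, seq_len, head_dim)

# pick the rows for the expanded beams
indicies = np.array([0, 0, 1])
key_std, value_std = key_std[indicies], value_std[indicies]

# back to bloom layout with the new batch size
new_batch = indicies.shape[0]
key_bloom = key_std.reshape(new_batch * num_heads, head_dim, seq_len)
value_bloom = value_std.reshape(new_batch * num_heads, seq_len, head_dim)
assert key_bloom.shape[0] == new_batch * num_heads
```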