 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
-from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
+from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.stopping_criteria import StoppingCriteriaList
-from transformers.generation.utils import GenerateOutput
+from transformers.generation.utils import GenerateOutput, GenerationMode
 from transformers.modeling_outputs import CausalLMOutputWithPast

 from optimum.utils.normalized_config import NormalizedConfigManager
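The import hunk only relocates GenerationMode: it is now taken from transformers.generation.utils, presumably because that module exposes the class both in releases where it is defined there and in newer ones where it is re-exported from configuration_utils. A minimal compatibility sketch, not part of the patch and assuming one of the two import paths exists in the installed transformers version:

# Resolve GenerationMode from whichever module provides it in this transformers release.
try:
    from transformers.generation.utils import GenerationMode
except ImportError:
    from transformers.generation.configuration_utils import GenerationMode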
@@ -386,10 +386,8 @@ def prepare_inputs(
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
-                if (
-                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
-                    or self.config.model_type == "falcon"
-                    and self.config.new_decoder_architecture
+                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                    self.config.model_type == "falcon" and self.config.new_decoder_architecture
                 ):
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
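In prepare_inputs the multi-line guard is collapsed into an explicitly parenthesised form. Because Python's "and" binds tighter than "or", the old and new conditions are logically equivalent; the rewrite only makes the grouping visible. A throwaway precedence check with stand-in booleans (not the real config attributes):

# Stand-ins: not_mqa ~ model_type not in MULTI_QUERY_ATTN_MODELS,
# falcon ~ model_type == "falcon", new_arch ~ config.new_decoder_architecture.
not_mqa, falcon, new_arch = False, True, True
implicit = not_mqa or falcon and new_arch    # how the old, unparenthesised form parses
explicit = not_mqa or (falcon and new_arch)  # the new, explicit form
assert implicit == explicit == True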
@@ -499,10 +497,8 @@ def forward(
         if self.use_cache:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
             past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
-            if (
-                self.config.model_type not in MULTI_QUERY_ATTN_MODELS
-                or self.config.model_type == "falcon"
-                and self.config.new_decoder_architecture
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                self.config.model_type == "falcon" and self.config.new_decoder_architecture
             ):
                 # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                 past_key_values = tuple(
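In forward the same condition guards the branch that regroups the flat tuple of cache tensors into per-layer (key, value) pairs, as the in-code comments describe. The body of that tuple(...) call lies outside the hunk, so the following is only a hypothetical illustration of the regrouping pattern, with made-up placeholder entries rather than real tensors:

# Hypothetical flat cache: one entry per output tensor, two per decoder layer.
flat_kv = ("k_0", "v_0", "k_1", "v_1")
per_layer = tuple(flat_kv[i : i + 2] for i in range(0, len(flat_kv), 2))
assert per_layer == (("k_0", "v_0"), ("k_1", "v_1"))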
@@ -559,10 +555,8 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
         if indicies.shape[0] != 1:
             logits = logits[indicies]
             if past_key_values and not self.stateful:
-                if (
-                    self.config.model_type not in MULTI_QUERY_ATTN_MODELS
-                    or self.config.model_type == "falcon"
-                    and self.config.new_decoder_architecture
+                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
+                    self.config.model_type == "falcon" and self.config.new_decoder_architecture
                 ):
                     past_key_values = tuple(
                         tuple(
@@ -581,7 +575,7 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
                     if self.next_beam_idx is not None
                     else np.arange(batch_size, dtype=int)[indicies]
                 )
-        self._second_iter_beam_search = True
+                self._second_iter_beam_search = True
         return logits, past_key_values

     def _deduplicate_inputs(self, model_inputs: Dict):
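The two _expand_outputs_for_generation hunks apply the same guard rewrite and, judging by the surrounding context lines, move the _second_iter_beam_search = True assignment deeper so it is set together with next_beam_idx rather than unconditionally. The helper itself expands outputs to one row per beam by fancy-indexing the batch dimension; a small standalone sketch of that pattern (names only mirror the source, the spelling "indicies" included):

import numpy as np

logits = np.array([[0.1, 0.9], [0.8, 0.2]])  # logits for a batch of 2 prompts
indicies = np.array([0, 0, 1, 1])            # 2 beams per prompt
expanded = logits[indicies]                  # each prompt's row repeated once per beam
assert expanded.shape == (4, 2)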
@@ -692,7 +686,7 @@ def _reorder_cache(
             self._second_iter_beam_search = False
             return past_key_values
         else:
-            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS and not (
+            if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
                 self.config.model_type == "falcon" and self.config.new_decoder_architecture
             ):
                 return tuple(
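Unlike the earlier hunks, the _reorder_cache change is not purely cosmetic: "and not (...)" becomes "or (...)". The two guards agree except when the model is a Falcon with new_decoder_architecture, which now satisfies the condition and takes the guarded return tuple(...) branch, bringing this method in line with the condition the other hunks use. A quick truth-table check with stand-in booleans (not the real config) makes the difference visible:

def old_guard(not_mqa, falcon_new_arch):
    return not_mqa and not falcon_new_arch

def new_guard(not_mqa, falcon_new_arch):
    return not_mqa or falcon_new_arch

for not_mqa in (True, False):
    for falcon_new_arch in (True, False):
        # the two guards disagree exactly when falcon_new_arch is True
        assert (old_guard(not_mqa, falcon_new_arch) != new_guard(not_mqa, falcon_new_arch)) == falcon_new_arch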