@@ -559,11 +559,9 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
559
559
if indicies .shape [0 ] != 1 :
560
560
logits = logits [indicies ]
561
561
if past_key_values and not self .stateful :
562
- if (
563
- self .config .model_type not in MULTI_QUERY_ATTN_MODELS
564
- or self .config .model_type == "falcon"
565
- and self .config .new_decoder_architecture
566
- ):
562
+ if (self .config .model_type not in MULTI_QUERY_ATTN_MODELS
563
+ or (self .config .model_type == "falcon"
564
+ and self .config .new_decoder_architecture )):
567
565
past_key_values = tuple (
568
566
tuple (
569
567
past_state [indicies ]
@@ -581,7 +579,7 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
581
579
if self .next_beam_idx is not None
582
580
else np .arange (batch_size , dtype = int )[indicies ]
583
581
)
584
- self ._second_iter_beam_search = True
582
+ self ._second_iter_beam_search = True
585
583
return logits , past_key_values
586
584
587
585
def _deduplicate_inputs (self , model_inputs : Dict ):
@@ -692,7 +690,7 @@ def _reorder_cache(
692
690
self ._second_iter_beam_search = False
693
691
return past_key_values
694
692
else :
695
- if self .config .model_type not in MULTI_QUERY_ATTN_MODELS and not (
693
+ if self .config .model_type not in MULTI_QUERY_ATTN_MODELS or (
696
694
self .config .model_type == "falcon" and self .config .new_decoder_architecture
697
695
):
698
696
return tuple (
0 commit comments