@@ -699,6 +699,7 @@ def gpt_bigcode_wrapped_scaled_dot_product(
     # MHA models: (batch_size, num_heads, query_length, head_dim)
     query_shape = query.shape
     batch_size = query_shape[0]
+    kv_seq_len = key.shape[-2]
 
     if self.multi_query:
         query_length = query_shape[1]
@@ -725,30 +726,34 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         key = key.expand(-1, self.num_heads, -1, -1)
         value = value.expand(-1, self.num_heads, -1, -1)
 
-    if batch_size == 1 or self.training:
-        if query_length > 1:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
-            )
-        else:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
-            )
+    # We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
+    if self.training:
+        is_causal = True
+        attn_mask = None
+    elif batch_size == 1 and query_length == 1:
+        is_causal = False
+        attn_mask = None
+    elif batch_size == 1 and kv_seq_len == query_length:
+        is_causal = True
+        attn_mask = None
+    elif attention_mask is not None:
+        mask_value = self._get_mask_value(query.device, query.dtype)
+
+        # gpt_bigcode has the bad taste to use a causal mask of shape
+        # [batch_size, target_length, 1, source_length], which is different from
+        # **all** other architectures and not compatible with SDPA.
+        # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
+        # but it is probably not worth it.
+        attention_mask = attention_mask.transpose(1, 2)
+        attn_mask = torch.where(attention_mask, 0.0, mask_value)
+        is_causal = False
     else:
-        if attention_mask is not None:
-            mask_value = self._get_mask_value(query.device, query.dtype)
+        attn_mask = None
+        is_causal = True
 
-            # gpt_bigcode has the bad taste to use a causal mask a
-            # [batch_size, target_length, 1, source_length] which is different from
-            # **all** other architectures and not compatible with SDPA.
-            # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
-            # but it is probably not worth it.
-            attention_mask = attention_mask.transpose(1, 2)
-            attention_mask = torch.where(attention_mask, 0.0, mask_value)
-
-        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
-        )
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+    )
 
     if self.multi_query:
         # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
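
Note on the change above (not part of the diff): an explicit attn_mask is now built only in the branch that really needs one, so the mask-free cases can still dispatch to PyTorch's Flash Attention backend. Below is a minimal standalone sketch of that dispatch and of the boolean-to-additive mask conversion; the dispatch_sdpa helper, the tensor shapes in the smoke test, and the use of torch.finfo(query.dtype).min in place of the model's _get_mask_value are illustrative assumptions, not the optimum implementation itself.

import torch


def dispatch_sdpa(query, key, value, attention_mask, training, dropout_p=0.0):
    # query/key/value: (batch_size, num_heads, seq_len, head_dim)
    batch_size, _, query_length, _ = query.shape
    kv_seq_len = key.shape[-2]

    if training:
        # Rely on SDPA's built-in causal masking; no explicit mask tensor.
        attn_mask, is_causal = None, True
    elif batch_size == 1 and query_length == 1:
        # Single-token decode: the lone query may attend to every key.
        attn_mask, is_causal = None, False
    elif batch_size == 1 and kv_seq_len == query_length:
        # Unbatched prefill without padding: built-in causal masking suffices.
        attn_mask, is_causal = None, True
    elif attention_mask is not None:
        # Boolean mask laid out as (batch_size, target_length, 1, source_length),
        # True = attend, False = masked. Transpose to the usual
        # (batch_size, 1, target_length, source_length) layout, then convert it
        # to the additive float mask SDPA expects (0.0 = keep, large negative = drop).
        mask_value = torch.full((), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device)
        attention_mask = attention_mask.transpose(1, 2)
        attn_mask = torch.where(attention_mask, 0.0, mask_value)
        is_causal = False
    else:
        attn_mask, is_causal = None, True

    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )


# Smoke test: batch of 2, the first sequence padded at position 0.
batch, heads, seq, dim = 2, 4, 6, 8
q = k = v = torch.randn(batch, heads, seq, dim)
causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
padding = torch.tensor([[0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], dtype=torch.bool)
# gpt_bigcode-style layout: (batch_size, target_length, 1, source_length)
mask = (causal & padding[:, None, :]).unsqueeze(2)
out = dispatch_sdpa(q, k, v, mask, training=False)
print(out.shape)  # torch.Size([2, 4, 6, 8])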