@@ -291,11 +291,17 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs):
+def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
-    # for compatibility with https://github.com/huggingface/transformers/pull/30047
-    current_length = kwargs.get("current_length", cache_position[-1])
+    if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+        # in order to dispatch on Flash Attention 2.
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+        ):
+            return None
+
     dtype, device = input_tensor.dtype, input_tensor.device
 
     # using minimum from dtype with larger bandwith (floa32) may lead to overflow
@@ -305,7 +311,13 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+        if past_seen_tokens is not None:
+            current_length = past_seen_tokens + sequence_length + 1
+        # TODO: remove after support of transformers >= v4.40.0
+        else:
+            current_length = cache_position[-1] + 1
+
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length
 
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
     if sequence_length != 1:
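For context on how the new `past_seen_tokens` argument is meant to be used: the caller passes the number of tokens already held in the KV cache, which lets the SDPA branch decide whether the explicit mask can be dropped (dispatching to the `is_causal` fast path) and lets the dynamic-cache branch compute `current_length` without the `**kwargs`/`current_length` plumbing removed above. Below is a minimal caller-side sketch; the `prepare_causal_mask` wrapper, the `model` object it receives, and the use of `DynamicCache` are illustrative assumptions, not part of this patch.

```python
# Illustrative sketch only (not from this patch): how a caller could derive
# past_seen_tokens and cache_position before invoking the patched helper,
# assuming the helper has been bound to the model as a method.
import torch
from transformers.cache_utils import DynamicCache


def prepare_causal_mask(model, attention_mask, inputs_embeds, past_key_values=None):
    # Hypothetical wrapper: `model` is assumed to expose the patched
    # _llama_gemma_update_causal_mask as a bound method.
    if past_key_values is None:
        past_key_values = DynamicCache()

    # Number of tokens already stored in the KV cache.
    past_seen_tokens = past_key_values.get_seq_length()

    # Positions of the tokens processed in this forward pass.
    cache_position = torch.arange(
        past_seen_tokens,
        past_seen_tokens + inputs_embeds.shape[1],
        device=inputs_embeds.device,
    )

    # Returns None when the SDPA `is_causal` fast path applies,
    # otherwise an additive causal mask tensor.
    return model._llama_gemma_update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_seen_tokens=past_seen_tokens
    )
```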