@@ -301,7 +301,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter

     if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
@@ -314,10 +314,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

     dtype, device = input_tensor.dtype, input_tensor.device

+    # difference with original modeling
     # using minimum from dtype with larger bandwidth (float32) may lead to overflow
     # during execution on platforms with default lower precision (bfloat16, float16)
     min_dtype = torch.finfo(torch.float16).min
     sequence_length = input_tensor.shape[1]
+    # difference with original modeling
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
@@ -329,7 +331,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po

         target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length

+    # difference with original modeling
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+
     if sequence_length != 1:
         causal_mask = torch.triu(causal_mask, diagonal=1)
     causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
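
The `# difference with original modeling` lines above swap the upstream fill value `torch.finfo(dtype).min` for `torch.finfo(torch.float16).min`. A minimal standalone sketch of the motivation (illustration only, not part of the patch): the float32 minimum is not representable in float16, so a mask filled with it overflows to -inf once the exported model runs in half precision, while the float16 minimum stays finite.

import torch

# float32 minimum overflows to -inf when the mask is cast down to float16
print(torch.tensor(torch.finfo(torch.float32).min).to(torch.float16))  # tensor(-inf, dtype=torch.float16)
# float16 minimum remains a finite "large negative" value
print(torch.tensor(torch.finfo(torch.float16).min).to(torch.float16))  # tensor(-65504., dtype=torch.float16)
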
@@ -366,6 +370,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     return causal_mask


+# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036
+def _llama_gemma_update_causal_mask_latest(
+    self,
+    attention_mask,
+    input_tensor,
+    cache_position,
+    past_key_values,
+    output_attentions,
+):
+    from transformers.cache_utils import StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using minimum from dtype with larger bandwidth (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+
+    sequence_length = input_tensor.shape[1]
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        # difference with original modeling
+        causal_mask = (
+            torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+        )
+
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+# TODO: deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0
+if is_transformers_version(">", "4.40.2"):
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest
+else:
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy
+
+
 class GemmaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
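
For reference, a standalone sketch of the mask layout these helpers build during a prefill step (illustration only, with toy sizes; the real functions additionally handle padding masks, custom 4D masks, and the flash-attention/SDPA short-circuits shown in the diff):

import torch

min_dtype = torch.finfo(torch.float16).min
sequence_length, target_length = 4, 6            # 4 new tokens, room for 6 cache positions
cache_position = torch.arange(sequence_length)   # positions of the tokens being processed

# same construction as in the patched functions above
causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=torch.float32) * min_dtype
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)

# row i keeps 0.0 (attend) for positions <= i and min_dtype (masked) for positions > i
print(causal_mask)
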