@@ -293,7 +293,7 @@ def __exit__(self, exc_type, exc_value, traceback):
# adopted from
# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
# https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
- def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
@@ -306,10 +306,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
    dtype, device = input_tensor.dtype, input_tensor.device

+     # difference with original modeling
    # using minimum from dtype with larger bandwidth (float32) may lead to overflow
    # during execution on platforms with default lower precision (bfloat16, float16)
    min_dtype = torch.finfo(torch.float16).min
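    # e.g. torch.tensor(torch.finfo(torch.float32).min, dtype=torch.float16) overflows to -inf,
    # while torch.finfo(torch.float16).min (-65504.0) stays finite in float32, bfloat16 and float16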
    sequence_length = input_tensor.shape[1]
+     # difference with original modeling
    if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
        target_length = self.config.max_position_embeddings
    else:  # dynamic cache
@@ -321,7 +323,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length

+     # difference with original modeling
    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
@@ -358,6 +362,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
    return causal_mask


+ # adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036
+ def _llama_gemma_update_causal_mask_latest(
+     self,
+     attention_mask,
+     input_tensor,
+     cache_position,
+     past_key_values,
+     output_attentions,
+ ):
+     from transformers.cache_utils import StaticCache
+     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+     # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
+     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+     if self.config._attn_implementation == "flash_attention_2":
+         if attention_mask is not None and 0.0 in attention_mask:
+             return attention_mask
+         return None
+
+     # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+     # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+     # to infer the attention mask.
+     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+     using_static_cache = isinstance(past_key_values, StaticCache)
+
+     # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
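+     # (roughly: _ignore_causal_mask_sdpa returns True when the 2D mask is None or all ones, i.e. no padding,
+     # so SDPA can rely on its is_causal argument and the explicit mask can be skipped)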
+     if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+         if AttentionMaskConverter._ignore_causal_mask_sdpa(
+             attention_mask,
+             inputs_embeds=input_tensor,
+             past_key_values_length=past_seen_tokens,
+             is_training=self.training,
+         ):
+             return None
+
+     dtype, device = input_tensor.dtype, input_tensor.device
+     # difference with original modeling
+     # using minimum from dtype with larger bandwidth (float32) may lead to overflow
+     # during execution on platforms with default lower precision (bfloat16, float16)
+     min_dtype = torch.finfo(torch.float16).min
+
+     sequence_length = input_tensor.shape[1]
+     if using_static_cache:
+         target_length = past_key_values.get_max_length()
+     else:
+         target_length = (
+             attention_mask.shape[-1]
+             if isinstance(attention_mask, torch.Tensor)
+             else past_seen_tokens + sequence_length + 1
+         )
+
+     if attention_mask is not None and attention_mask.dim() == 4:
+         # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+         if attention_mask.max() != 0:
+             raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+         causal_mask = attention_mask
+     else:
+         # difference with original modeling
+         causal_mask = (
+             torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+         )
+
+         if sequence_length != 1:
+             causal_mask = torch.triu(causal_mask, diagonal=1)
+         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
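+         # e.g. sequence_length=2, target_length=4, cache_position=[2, 3] yields
+         # [[0, 0, 0, m], [0, 0, 0, 0]] with m = min_dtype: the query at position 2
+         # may not attend to position 3, the query at position 3 attends to everything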
+         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+         if attention_mask is not None:
+             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+             mask_length = attention_mask.shape[-1]
+             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+             padding_mask = padding_mask == 0
+             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                 padding_mask, min_dtype
+             )
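+             # (positions where the causal part is 0, i.e. visible, but attention_mask is 0, i.e. padding,
+             # sum to 0 in padding_mask and are re-masked with min_dtype above)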
+     if (
+         self.config._attn_implementation == "sdpa"
+         and attention_mask is not None
+         and attention_mask.device.type == "cuda"
+         and not output_attentions
+     ):
+         # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+         # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+         # Details: https://github.com/pytorch/pytorch/issues/110213
+         causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
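+         # (a row that is min_dtype everywhere would softmax to NaN in the memory-efficient kernel,
+         # so _unmask_unattended resets such fully masked rows to 0, letting them attend to all positions)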
+
+     return causal_mask
+
+
+ # TODO: deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0
+ if is_transformers_version(">=", "4.41.0"):
+     _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest
+ else:
+     _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy
+
+
class GemmaModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()