@@ -2501,6 +2501,40 @@ def __enter__(self):
             _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)


+# Adapted from https://github.com/huggingface/transformers/blob/31f9a289a6207be6cae746e009d8e0db523be203/src/transformers/models/falcon/modeling_falcon.py#L1138
+def _falcon_prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    cache_position: torch.Tensor,
+    batch_size: int,
+    **kwargs,
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        # different from original: allow to provide min_dtype as parameter
+        min_dtype = torch.finfo(dtype).min if "min_dtype" not in kwargs else kwargs["min_dtype"]
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+    return causal_mask
+
+
 def _falcon_update_causal_mask(
     self,
     attention_mask: torch.Tensor,
@@ -2520,13 +2554,6 @@ def _falcon_update_causal_mask(
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

-    if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
-        _prepare_4d_causal_attention_mask_with_cache_position = (
-            self._prepare_4d_causal_attention_mask_with_cache_position
-        )
-    else:
-        from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position
-
     if self.config._attn_implementation == "flash_attention_2":
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
@@ -2568,7 +2595,7 @@ def _falcon_update_causal_mask(
         )

     # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+    causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
         attention_mask,
         sequence_length=sequence_length,
         target_length=target_length,
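
For reference, a minimal illustrative sketch (not part of the diff) of how the helper added above could be exercised; the batch size, sequence length, and attention-mask values below are assumptions chosen only to show the expected output shape and masking behaviour.

import torch

# Hypothetical inputs: a batch of 2 sequences of length 3, the first one left-padded.
attention_mask = torch.tensor([[0, 1, 1], [1, 1, 1]])

causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=3,
    target_length=3,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=torch.arange(3),
    batch_size=2,
)

# Expected: shape (batch_size, 1, sequence_length, target_length); future and padded
# positions are filled with torch.finfo(torch.float32).min, attended positions are 0.
print(causal_mask.shape)  # torch.Size([2, 1, 3, 3])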