import torch
import torch.nn.functional as F
+ from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import is_tf_available

@@ -1397,9 +1398,173 @@ def _dbrx_experts_forward(
    return out


+ def _dbrx_update_causal_mask_legacy(
+     self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
+ ) -> Optional[torch.Tensor]:
+     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+     if self.config._attn_implementation == "flash_attention_2":
+         if attention_mask is not None and 0.0 in attention_mask:
+             return attention_mask
+         return None
+
+     dtype, device = input_tensor.dtype, input_tensor.device
+     min_dtype = torch.finfo(torch.float16).min
+     sequence_length = input_tensor.shape[1]
+     if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"):  # static cache
+         target_length = self.config.max_position_embeddings
+     else:  # dynamic cache
+         target_length = (
+             attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+         )
+
+     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+     if sequence_length != 1:
+         causal_mask = torch.triu(causal_mask, diagonal=1)
+     causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+     causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+     if attention_mask is not None:
+         causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+         if attention_mask.dim() == 2:
+             mask_length = attention_mask.shape[-1]
+             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+             padding_mask = padding_mask == 0
+             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                 padding_mask, min_dtype
+             )
+         elif attention_mask.dim() == 4:
+             # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+             # cache. In that case, the 4D attention mask attends to the newest tokens only.
+             if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                 offset = cache_position[0]
+             else:
+                 offset = 0
+             mask_shape = attention_mask.shape
+             mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+             causal_mask[
+                 : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+             ] = mask_slice
+
+     if (
+         self.config._attn_implementation == "sdpa"
+         and attention_mask is not None
+         and attention_mask.device.type == "cuda"
+     ):
+         # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+         is_tracing = (
+             torch.jit.is_tracing()
+             or isinstance(input_tensor, torch.fx.Proxy)
+             or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+         )
+         if not is_tracing and torch.any(attention_mask != 1):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+     return causal_mask
+
+
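For orientation, here is a minimal, self-contained sketch of the mask-building pattern used in `_dbrx_update_causal_mask_legacy` above; the sizes and the example `cache_position` values are illustrative only, not taken from the patch. Rows end up with 0.0 at key positions a query may attend to and the large negative fill elsewhere, which is what gets added to the attention scores before softmax.

import torch

# Toy reproduction of the mask-building steps above: start fully masked, keep the
# strict upper triangle masked, then unmask key positions at or before each query's
# cache position.
seq_len, target_len = 4, 6
cache_position = torch.arange(2, 2 + seq_len)  # absolute positions of the new tokens
min_dtype = torch.finfo(torch.float16).min

causal_mask = torch.full((seq_len, target_len), fill_value=1, dtype=torch.float32) * min_dtype
if seq_len != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)
print(causal_mask)  # 0.0 where attention is allowed, a large negative value elsewhere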
+ def _dbrx_update_causal_mask_latest(
+     self,
+     attention_mask: torch.Tensor,
+     input_tensor: torch.Tensor,
+     cache_position: torch.Tensor,
+     past_key_values: Cache,
+     output_attentions: bool,
+ ):
+     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+     # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
+     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+     if self.config._attn_implementation == "flash_attention_2":
+         if attention_mask is not None and 0.0 in attention_mask:
+             return attention_mask
+         return None
+
+     # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+     # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+     # to infer the attention mask.
+     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+     using_static_cache = isinstance(past_key_values, StaticCache)
+
+     # When output_attentions is True, the sdpa implementation's forward method calls the eager implementation's forward
+     if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+         if AttentionMaskConverter._ignore_causal_mask_sdpa(
+             attention_mask,
+             inputs_embeds=input_tensor,
+             past_key_values_length=past_seen_tokens,
+             is_training=self.training,
+         ):
+             return None
+
+     dtype, device = input_tensor.dtype, input_tensor.device
+     # difference with original modeling:
+     # using the minimum of the dtype with larger bandwidth (float32) may lead to overflow
+     # during execution on platforms with default lower precision (bfloat16, float16)
+     min_dtype = torch.finfo(torch.float16).min
+     sequence_length = input_tensor.shape[1]
+     if using_static_cache:
+         target_length = past_key_values.get_max_length()
+     else:
+         target_length = (
+             attention_mask.shape[-1]
+             if isinstance(attention_mask, torch.Tensor)
+             else past_seen_tokens + sequence_length + 1
+         )
+
+     if attention_mask is not None and attention_mask.dim() == 4:
+         # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+         if attention_mask.max() != 0:
+             raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0")
+         causal_mask = attention_mask
+     else:
+         # difference with original modeling
+         causal_mask = (
+             torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+         )
+         if sequence_length != 1:
+             causal_mask = torch.triu(causal_mask, diagonal=1)
+         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+         if attention_mask is not None:
+             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+             mask_length = attention_mask.shape[-1]
+             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+             padding_mask = padding_mask == 0
+             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                 padding_mask, min_dtype
+             )
+     if (
+         self.config._attn_implementation == "sdpa"
+         and attention_mask is not None
+         and attention_mask.device.type == "cuda"
+         and not output_attentions
+     ):
+         # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+         # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+         # Details: https://github.com/pytorch/pytorch/issues/110213
+         causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+     return causal_mask
+
+
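Both variants take the mask fill value from `torch.float16` rather than from the widest dtype, as the "difference with original modeling" comment notes. A quick editorial sketch of why (not part of the patch):

import torch

# float32's most negative value is not representable in float16: the cast overflows
# to -inf, and -inf entries can propagate NaNs through softmax when a row is fully masked.
print(torch.tensor(torch.finfo(torch.float32).min).to(torch.float16))  # tensor(-inf, dtype=torch.float16)

# float16's most negative value stays finite after casting and is still large enough to mask scores.
print(torch.tensor(torch.finfo(torch.float16).min))  # tensor(-65504.)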
+ if is_transformers_version(">", "4.40.2"):
+     _dbrx_update_causal_mask = _dbrx_update_causal_mask_latest
+ else:
+     _dbrx_update_causal_mask = _dbrx_update_causal_mask_legacy
+
+
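Both mask functions fold a 2D padding mask into the inverted 4D causal mask the same way; a small sketch with toy shapes and values (illustrative only):

import torch

# A position is re-masked exactly where the causal mask allows it (0.0) but the token is
# padding (0), i.e. where the sum of the two masks is 0.
min_dtype = torch.finfo(torch.float16).min
causal_mask = torch.zeros(1, 1, 1, 4)          # one decode-step query over 4 key positions, all allowed
attention_mask = torch.tensor([[0, 1, 1, 1]])  # batch of one, left-padded by one token
padding_mask = causal_mask[:, :, :, :4] + attention_mask[:, None, None, :]
causal_mask = causal_mask.masked_fill(padding_mask == 0, min_dtype)
print(causal_mask)  # tensor([[[[-65504., 0., 0., 0.]]]]) -> the padded key position is masked out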
class DBRXModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
+         self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask
+         self._model.transformer._update_causal_mask = types.MethodType(
+             _dbrx_update_causal_mask, self._model.transformer
+         )

        for block in self._model.transformer.blocks:
            rotary_emb = block.norm_attn_norm.attn.rotary_emb
@@ -1413,5 +1578,6 @@ def __enter__(self):

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
+         self._model.transformer._update_causal_mask = self._model.transformer._orig_update_causal_mask
        for block in self._model.transformer.blocks:
            block.ffn.experts.forward = block.ffn.experts._orig_forward
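The patcher swaps the model's bound `_update_causal_mask` in `__enter__` and restores it in `__exit__`. A minimal sketch of that `types.MethodType` pattern on a stand-in class (`Dummy` and `patched_greet` are hypothetical names used only for illustration):

import types


class Dummy:
    def greet(self):
        return "original"


def patched_greet(self):
    # stand-in replacement; `self` is the Dummy instance once bound via MethodType
    return "patched"


obj = Dummy()
obj._orig_greet = obj.greet                        # keep a handle on the original bound method
obj.greet = types.MethodType(patched_greet, obj)   # bind the replacement to this instance only
assert obj.greet() == "patched"
obj.greet = obj._orig_greet                        # restore, as DBRXModelPatcher.__exit__ does
assert obj.greet() == "original"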