@@ -1468,6 +1468,7 @@ def _dbrx_experts_forward(
    return out


+# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
def _dbrx_update_causal_mask_legacy(
    self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
) -> Optional[torch.Tensor]:
@@ -1479,6 +1480,9 @@ def _dbrx_update_causal_mask_legacy(
        return None

    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling:
+    # using the minimum of a wider dtype (float32) may lead to overflow
+    # during execution on platforms whose default precision is lower (bfloat16, float16)
    min_dtype = torch.finfo(torch.float16).min
    sequence_length = input_tensor.shape[1]
    if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"):  # static cache
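A note on the min_dtype change above: torch.finfo(torch.float16).min is used because, for example, the float32 minimum is not representable in float16, so a mask built from it collapses to -inf once the model runs in half precision; that is the overflow the comment warns about. A minimal sketch, independent of the patch:

import torch

# the float32 minimum (~ -3.4e38) lies outside the float16 range (~ ±65504),
# so casting it down saturates to -inf, while the float16 minimum stays finite
print(torch.tensor(torch.finfo(torch.float32).min).to(torch.float16))     # tensor(-inf, dtype=torch.float16)
print(torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16))  # tensor(-65504., dtype=torch.float16)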
@@ -1487,7 +1491,9 @@ def _dbrx_update_causal_mask_legacy(
        target_length = (
            attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
        )
-
+    # difference with original modeling:
+    # removed target_length = int(target_length).
+    # Casting to int leads to constant folding during tracing, which makes it impossible to run the traced model on sequences of a different length.
    causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
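The removed int(target_length) cast matters because torch.jit.trace records a Python scalar as a constant in the graph. A minimal sketch with a hypothetical make_mask helper (not part of the patch):

import torch

def make_mask(cache_position):
    # int(...) converts the traced tensor to a Python constant,
    # so the traced graph keeps this particular value forever
    target_length = int(cache_position[-1] + 1)
    return torch.zeros(target_length)

traced = torch.jit.trace(make_mask, torch.arange(4))  # warns about the tensor-to-int conversion
print(traced(torch.arange(4)).shape)  # torch.Size([4])
print(traced(torch.arange(8)).shape)  # still torch.Size([4]): the length was folded at trace time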
@@ -1535,6 +1541,7 @@ def _dbrx_update_causal_mask_legacy(
    return causal_mask


+# adapted from https://github.com/huggingface/transformers/blob/1b3dba9417eebe16b7c206d1dfca6a4c7f11dbec/src/transformers/models/dbrx/modeling_dbrx.py#L1204
def _dbrx_update_causal_mask_latest(
    self,
    attention_mask: torch.Tensor,
@@ -1631,18 +1638,22 @@ def _dbrx_update_causal_mask_latest(
class DBRXModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
+        # dbrx has accuracy issues in bf16 with transformers >= 4.40;
+        # fill the causal mask in a slightly different way to avoid overflow on some platforms
        self._model.transformer._orig_update_causal_mask = self._model.transformer._update_causal_mask
        self._model.transformer._update_causal_mask = types.MethodType(
            _dbrx_update_causal_mask, self._model.transformer
        )

        for block in self._model.transformer.blocks:
            rotary_emb = block.norm_attn_norm.attn.rotary_emb
+            # initialize inv_freq for torchscript tracing
            if rotary_emb.inv_freq is None:
                inv_freq = 1.0 / (
                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                )
                rotary_emb.inv_freq = inv_freq
+            # remove the continue operator from the iteration loop over experts
            block.ffn.experts._orig_forward = block.ffn.experts.forward
            block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
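On the experts-forward patch: torch.jit.trace only records the operations executed for the example inputs, so a data-dependent continue inside the expert loop would bake one particular routing decision into the graph, which is presumably why _dbrx_experts_forward rewrites the loop without it. A toy illustration with a hypothetical gate function (not DBRX code):

import torch

def gate(x, mask):
    out = torch.zeros_like(x)
    for i in range(2):
        if not mask[i]:
            continue  # data-dependent branch: the tracer records only the path taken
        out = out + x
    return out

traced = torch.jit.trace(gate, (torch.ones(3), torch.tensor([True, False])))  # warns about tensor-to-bool casts
print(traced(torch.ones(3), torch.tensor([True, True])))  # tensor([1., 1., 1.]) although eager gives [2., 2., 2.]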