@@ -531,9 +531,10 @@ def _qwen_attention_forward(
         value = value.permute(0, 2, 1, 3)
 
     if not self.use_cache_quantization and SUPPORT_SDPA:
-        causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
+        # For performance, using constant tril to generate causal_mask
+        causal_mask = self.bias[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
         if attention_mask is not None:
-            attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill(
+            attention_mask = attention_mask.expand(-1, -1, query.size(2), -1).masked_fill(
                 ~causal_mask, torch.finfo(query.dtype).min
             )
         else:
@@ -578,8 +579,17 @@ def __init__(
 
     def __enter__(self):
         super().__enter__()
+        max_positions = self._model.config.seq_length
        for block in self._model.transformer.h:
             block.attn._orig_forward = block.attn.forward
+            # For performance, using constant tril to generate causal_mask
+            block.attn.register_buffer(
+                "bias",
+                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+                    1, 1, max_positions, max_positions
+                ),
+                persistent=False,
+            )
             block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn)
 
     def __exit__(self, exc_type, exc_value, traceback):
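
Outside the patch itself, a minimal, self-contained sketch of the constant-tril trick the diff relies on: a lower-triangular boolean buffer is built once for the maximum sequence length and sliced per step, instead of rebuilding the causal mask on every forward. The `max_positions`, query, and key lengths below are illustrative assumptions, not values taken from the Qwen config.

import torch

# Stand-in for self._model.config.seq_length; the real value comes from the model config.
max_positions = 8

# Registered once as block.attn.bias in the patcher above.
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

# Prefill: query and key cover the same 4 tokens -> lower-triangular 4x4 mask.
q_len, k_len = 4, 4
causal_mask = bias[:, :, k_len - q_len : k_len, :k_len]
assert causal_mask.shape == (1, 1, 4, 4)

# Decode step with a KV cache: 1 new query token attends to 5 keys -> a single all-True row.
q_len, k_len = 1, 5
causal_mask = bias[:, :, k_len - q_len : k_len, :k_len]
assert causal_mask.shape == (1, 1, 1, 5)
assert bool(causal_mask.all())

# Disallowed positions are filled with the dtype minimum, mirroring the
# masked_fill(~causal_mask, torch.finfo(query.dtype).min) call in the diff.
attention_mask = torch.zeros(1, 1, q_len, k_len).masked_fill(
    ~causal_mask, torch.finfo(torch.float32).min
)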