@@ -634,16 +634,31 @@ def has_flash_attn(self, query):
         elif query.device.type == "xpu":
             return is_torch_version(">", "2.5.99")
 
-    def varlen_attn(self, query, key, value, past_key_value, input_lens):
-        # prefill, remove padding
-        attn_output = torch.empty_like(query)
-        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-        if self.has_flash_attn(query):
+    def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens):
+        if past_key_value is None:
+            n_rep = query.shape[1] // key.shape[1]
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
+                key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                .transpose(1, 2)
+                .repeat_interleave(n_rep, 1),
+                value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                .transpose(1, 2)
+                .repeat_interleave(n_rep, 1),
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=True,
+            )
+            self.use_sdpa = True
+        elif self.has_flash_attn(query):
+            # prefill, remove padding
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
-                key,
-                value,
+                key_cache,
+                value_cache,
                 seq_len_tensor,
                 seq_len_tensor,
                 input_lens.max(),
@@ -654,6 +669,9 @@ def varlen_attn(self, query, key, value, past_key_value, input_lens):
                 None,
             )
         else:
+            # prefill, remove padding
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
@@ -697,23 +715,9 @@ def forward(
 
         if past_len == 0:
             # prefill
-            if past_key_value is None:
-                n_rep = query.shape[1] // key.shape[1]
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
-                    .transpose(1, 2)
-                    .repeat_interleave(n_rep, 1),
-                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
-                    .transpose(1, 2)
-                    .repeat_interleave(n_rep, 1),
-                    attn_mask=attention_mask,
-                    dropout_p=0.0,
-                    is_causal=True,
-                )
-                self.use_sdpa = True
-            else:
-                attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens)
+            attn_output = self.prefill_attn(
+                query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens
+            )
         else:
             # decode
             attn_output = torch.empty_like(query)
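For context on the SDPA branch that `prefill_attn` now owns, the snippet below is a minimal, self-contained sketch of its reshape-and-GQA logic on dummy tensors. The sizes, the equal per-sequence lengths, and the `attn_mask=None` simplification are illustrative assumptions and are not part of this patch.

# Hypothetical standalone sketch of the padded-SDPA prefill path; shapes are made up.
import torch

batch, max_len, head_dim = 2, 4, 8
num_q_heads, num_kv_heads = 8, 2
input_lens = torch.tensor([max_len, max_len])  # per-sequence prompt lengths (assumed equal here)

# Flattened projections: one row per token, matching the variable-length layout in the diff.
query = torch.randn(batch * max_len, num_q_heads, head_dim)
key = torch.randn(batch * max_len, num_kv_heads, head_dim)
value = torch.randn(batch * max_len, num_kv_heads, head_dim)

# GQA: number of query heads that share each KV head.
n_rep = query.shape[1] // key.shape[1]

# Un-flatten to (batch, heads, seq, head_dim) and broadcast KV heads so SDPA sees matching head counts.
q = query.reshape(batch, max_len, -1, head_dim).transpose(1, 2)
k = key.reshape(batch, max_len, -1, head_dim).transpose(1, 2).repeat_interleave(n_rep, 1)
v = value.reshape(batch, max_len, -1, head_dim).transpose(1, 2).repeat_interleave(n_rep, 1)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
)
print(attn_output.shape)  # (batch, num_q_heads, max_len, head_dim)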