File tree 1 file changed +5
-4
lines changed
1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -652,18 +652,19 @@ def attention_interface(
652
652
is_causal = True ,
653
653
)
654
654
self .use_sdpa = True
655
- elif self .has_flash_attn (query ) and past_len == 0 :
656
- # prefill, remove padding
655
+ elif self .has_flash_attn (query ):
657
656
attn_output = torch .empty_like (query )
658
657
seq_len_tensor = torch .cat ((input_lens .new_tensor ([0 ]), input_lens .cumsum (- 1 ).int ()))
658
+ query_len_tensor = seq_len_tensor if past_len == 0 else torch .arange (seq_len_tensor .shape [0 ])
659
+ query_max_len = input_lens .max () if past_len == 0 else 1
659
660
PagedAttention .flash_attn_varlen_func (
660
661
attn_output ,
661
662
query .contiguous () if query .device .type == "xpu" else query ,
662
663
key_cache .contiguous () if key_cache .device .type == "xpu" else key_cache ,
663
664
value_cache .contiguous () if value_cache .device .type == "xpu" else value_cache ,
665
+ query_len_tensor ,
664
666
seq_len_tensor ,
665
- seq_len_tensor ,
666
- input_lens .max (),
667
+ query_max_len ,
667
668
input_lens .max (),
668
669
1.0 / math .sqrt (self .head_dim ),
669
670
True ,
You can’t perform that action at this time.
0 commit comments