1 file changed: +3 -5 lines changed

@@ -652,19 +652,17 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn(query) and past_len == 0:
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
-            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
-                query_len_tensor,
                 seq_len_tensor,
-                query_max_len,
+                seq_len_tensor,
+                input_lens.max(),
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
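
For context, a minimal sketch of what the cumulative-offset tensor built in this hunk looks like, assuming a toy batch whose `input_lens` values are made up for illustration. After this change the varlen flash-attention path only runs when `past_len == 0` (pure prefill), so queries and keys/values are packed identically; reading the positional arguments as cu_seqlens_q / cu_seqlens_kv and max_seqlen_q / max_seqlen_kv (an assumption based on the call order), that is presumably why the removed `query_len_tensor` / `query_max_len` branches are no longer needed and `seq_len_tensor` and `input_lens.max()` can simply be passed twice.

import torch

# Toy per-sequence prompt lengths for a batch of 3 prefill requests
# (illustrative values, not taken from the PR).
input_lens = torch.tensor([4, 2, 5], dtype=torch.int32)

# Same construction as in the hunk: a leading 0 followed by the running
# sum of lengths. Each adjacent pair marks where one packed sequence
# ends and the next begins.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)    # tensor([ 0,  4,  6, 11], dtype=torch.int32)

# With past_len == 0 (pure prefill) the query offsets coincide with the
# key/value offsets, so the same tensor and the same max length serve
# both sides of the varlen attention call.
print(input_lens.max())  # tensor(5, dtype=torch.int32)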