1 file changed, +21 −7 lines changed

@@ -614,20 +614,34 @@ def forward(
         if past_len == 0:
             # prefill, remove padding
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            varlen_attention(
-                query.contiguous() if query.device.type == "xpu" else query,
-                key.contiguous() if key.device.type == "xpu" else key,
-                value.contiguous() if value.device.type == "xpu" else value,
+            # varlen_attention(
+            #     query.contiguous() if query.device.type == "xpu" else query,
+            #     key.contiguous() if key.device.type == "xpu" else key,
+            #     value.contiguous() if value.device.type == "xpu" else value,
+            #     attn_output,
+            #     seq_len_tensor,
+            #     seq_len_tensor,
+            #     input_lens.max(),
+            #     input_lens.max(),
+            #     0.0,
+            #     1.0 / math.sqrt(self.head_dim),
+            #     False,
+            #     True,
+            #     False,
+            #     None,
+            # )
+            PagedAttention.flash_attn_varlen_func(
                 attn_output,
+                query,
+                key_cache,
+                value_cache,
                 seq_len_tensor,
                 seq_len_tensor,
                 input_lens.max(),
                 input_lens.max(),
-                0.0,
                 1.0 / math.sqrt(self.head_dim),
-                False,
                 True,
-                False,
+                past_key_value.block_tables,
                 None,
             )
         else:
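Note on the unchanged setup line: `seq_len_tensor` is the cumulative-sequence-lengths tensor that varlen attention kernels use to locate each unpadded sequence inside the packed query/key/value tensors; it is passed twice, presumably as both the query-side and key-side offsets, which coincide during prefill. A minimal standalone sketch of that construction (the batch values here are invented for illustration):

    import torch

    # Per-sequence prompt lengths for a batch of three unpadded sequences (illustrative values).
    input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

    # Prepend 0 and take the running sum: adjacent entries [start, end) delimit
    # one sequence's slice along the packed token dimension.
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    print(seq_len_tensor)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)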
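The replacement call reads keys and values from the paged cache (`key_cache`, `value_cache`) and passes `past_key_value.block_tables`, the per-sequence table that maps a sequence's logical cache blocks to physical blocks in the shared pool, which is the core bookkeeping of paged attention. A toy sketch of that mapping (the block size, values, and helper name are invented for this sketch, not taken from the patch):

    import torch

    BLOCK_SIZE = 16  # tokens per physical cache block (illustrative)

    # Block table for one 40-token sequence: logical blocks 0..2 live in
    # physical blocks 5, 2, and 7 of the shared pool.
    block_table = torch.tensor([5, 2, 7])

    def logical_to_physical(token_idx: int) -> tuple[int, int]:
        # Map a token position within the sequence to (physical_block, slot).
        logical_block, slot = divmod(token_idx, BLOCK_SIZE)
        return int(block_table[logical_block]), slot

    print(logical_to_physical(37))  # (7, 5): token 37 sits in physical block 7, slot 5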