Skip to content

Commit 6d21075

Browse files
committed
use real varlen attn
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 41d9a37 commit 6d21075

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

optimum/exporters/ipex/modeling_utils.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -614,20 +614,34 @@ def forward(
         if past_len == 0:
             # prefill, remove padding
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            varlen_attention(
-                query.contiguous() if query.device.type == "xpu" else query,
-                key.contiguous() if key.device.type == "xpu" else key,
-                value.contiguous() if value.device.type == "xpu" else value,
+            # varlen_attention(
+            #     query.contiguous() if query.device.type == "xpu" else query,
+            #     key.contiguous() if key.device.type == "xpu" else key,
+            #     value.contiguous() if value.device.type == "xpu" else value,
+            #     attn_output,
+            #     seq_len_tensor,
+            #     seq_len_tensor,
+            #     input_lens.max(),
+            #     input_lens.max(),
+            #     0.0,
+            #     1.0 / math.sqrt(self.head_dim),
+            #     False,
+            #     True,
+            #     False,
+            #     None,
+            # )
+            PagedAttention.flash_attn_varlen_func(
                 attn_output,
+                query,
+                key_cache,
+                value_cache,
                 seq_len_tensor,
                 seq_len_tensor,
                 input_lens.max(),
                 input_lens.max(),
-                0.0,
                 1.0 / math.sqrt(self.head_dim),
-                False,
                 True,
-                False,
+                past_key_value.block_tables,
                 None,
             )
         else:

0 commit comments

Comments (0)