Commit 71aa6b0

use flash attn for decode
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent: 95b7043

1 file changed: optimum/exporters/ipex/modeling_utils.py (+5 -4 lines)
@@ -652,18 +652,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
-            # prefill, remove padding
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0])
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                query_len_tensor,
                 seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
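
Context for the change (not part of the commit): a variable-length flash-attention call describes the query and key/value batches with cumulative-length offset tensors. Previously this branch was taken only for prefill (past_len == 0), where the query offsets coincide with the KV offsets, so seq_len_tensor and input_lens.max() were each passed twice. By also taking the branch during decode, where each sequence contributes exactly one new query token, the query offsets reduce to [0, 1, ..., batch_size] and the maximum query length to 1. Below is a minimal sketch of that distinction, with made-up input lengths and illustrative variable names mirroring the diff; it is not the actual modeling_utils.py code.

# Sketch: how the cumulative-length tensors differ between prefill and decode.
import torch

input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # per-sequence KV lengths (made up)
batch_size = input_lens.shape[0]

# Cumulative KV offsets are the same in both phases: [0, 5, 8, 15]
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

for past_len in (0, 4):  # 0 -> prefill, >0 -> a decode step
    if past_len == 0:
        # Prefill: every token of every sequence is a query token, so the
        # query offsets are identical to the KV offsets.
        query_len_tensor = seq_len_tensor
        query_max_len = int(input_lens.max())
    else:
        # Decode: each sequence contributes exactly one new query token, so
        # the cumulative query offsets are simply [0, 1, ..., batch_size].
        query_len_tensor = torch.arange(batch_size + 1, dtype=torch.int32)
        query_max_len = 1
    print(f"past_len={past_len}: cu_seqlens_q={query_len_tensor.tolist()}, max_q={query_max_len}")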
