Skip to content

Commit c6d2d0f

Browse files
committed
decoding: use a single query
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 12dd802 commit c6d2d0f

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

optimum/exporters/ipex/modeling_utils.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -652,19 +652,17 @@ def attention_interface(
652652
is_causal=True,
653653
)
654654
self.use_sdpa = True
655-
elif self.has_flash_attn(query):
655+
elif self.has_flash_attn(query) and past_len == 0:
656656
attn_output = torch.empty_like(query)
657657
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
658-
query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
659-
query_max_len = input_lens.max() if past_len == 0 else 1
660658
PagedAttention.flash_attn_varlen_func(
661659
attn_output,
662660
query.contiguous() if query.device.type == "xpu" else query,
663661
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
664662
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
665-
query_len_tensor,
666663
seq_len_tensor,
667-
query_max_len,
664+
seq_len_tensor,
665+
input_lens.max(),
668666
input_lens.max(),
669667
1.0 / math.sqrt(self.head_dim),
670668
True,

0 commit comments

Comments (0)