
Commit 372d3f8

prefill attn

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

1 parent: 4dd2e44

1 file changed (+28 −24)

optimum/exporters/ipex/modeling_utils.py

@@ -634,16 +634,31 @@ def has_flash_attn(self, query):
         elif query.device.type == "xpu":
             return is_torch_version(">", "2.5.99")
 
-    def varlen_attn(self, query, key, value, past_key_value, input_lens):
-        # prefill, remove padding
-        attn_output = torch.empty_like(query)
-        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-        if self.has_flash_attn(query):
+    def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens):
+        if past_key_value is None:
+            n_rep = query.shape[1] // key.shape[1]
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
+                key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                .transpose(1, 2)
+                .repeat_interleave(n_rep, 1),
+                value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                .transpose(1, 2)
+                .repeat_interleave(n_rep, 1),
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=True,
+            )
+            self.use_sdpa = True
+        elif self.has_flash_attn(query):
+            # prefill, remove padding
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
-                key,
-                value,
+                key_cache,
+                value_cache,
                 seq_len_tensor,
                 seq_len_tensor,
                 input_lens.max(),
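
The new `past_key_value is None` branch runs prefill through PyTorch's `scaled_dot_product_attention`, first expanding the grouped KV heads to match the query head count with `repeat_interleave`. A minimal sketch of that reshape-and-expand pattern (shapes and values are illustrative, and the explicit `attn_mask` is omitted for brevity; this is not the module's actual layout):

```python
import torch

# Illustrative shapes: flattened tokens -> padded [batch, heads, seq, head_dim].
batch, max_len, head_dim = 2, 4, 8
num_q_heads, num_kv_heads = 8, 2     # grouped-query attention: fewer KV heads
n_rep = num_q_heads // num_kv_heads  # how often each KV head is shared

# Flattened layout [batch * max_len, heads, head_dim], like the diff's inputs.
query = torch.randn(batch * max_len, num_q_heads, head_dim)
key = torch.randn(batch * max_len, num_kv_heads, head_dim)
value = torch.randn(batch * max_len, num_kv_heads, head_dim)

q = query.reshape(batch, max_len, -1, head_dim).transpose(1, 2)  # [B, Hq, T, D]
k = key.reshape(batch, max_len, -1, head_dim).transpose(1, 2).repeat_interleave(n_rep, 1)
v = value.reshape(batch, max_len, -1, head_dim).transpose(1, 2).repeat_interleave(n_rep, 1)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 4, 8])
```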
@@ -654,6 +669,9 @@ def varlen_attn(self, query, key, value, past_key_value, input_lens):
                 None,
             )
         else:
+            # prefill, remove padding
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
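
Both variable-length branches drop padding and build `seq_len_tensor`, the cumulative-offset vector the varlen kernels use to locate each sequence inside the flattened token buffer. A small illustration of what it contains (the example lengths are made up):

```python
import torch

# Per-sequence prompt lengths for a batch of 3 (illustrative values).
input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Offsets [0, 5, 8, 15]: sequence i occupies tokens
# seq_len_tensor[i]:seq_len_tensor[i+1] in the flattened (padding-removed)
# query/key/value buffers. .int() casts the cumsum back to int32.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```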
@@ -697,23 +715,9 @@ def forward(
 
         if past_len == 0:
             # prefill
-            if past_key_value is None:
-                n_rep = query.shape[1] // key.shape[1]
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
-                    .transpose(1, 2)
-                    .repeat_interleave(n_rep, 1),
-                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
-                    .transpose(1, 2)
-                    .repeat_interleave(n_rep, 1),
-                    attn_mask=attention_mask,
-                    dropout_p=0.0,
-                    is_causal=True,
-                )
-                self.use_sdpa = True
-            else:
-                attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens)
+            attn_output = self.prefill_attn(
+                query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens
+            )
         else:
             # decode
             attn_output = torch.empty_like(query)
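
With all three prefill branches folded into `prefill_attn`, `forward` reduces to a single dispatch on `past_len`. A hypothetical stand-alone helper sketching that rule (the function name is illustrative, not from the module):

```python
def choose_attention_path(past_len: int) -> str:
    """Hypothetical sketch of the forward() dispatch after this commit:
    no cached tokens means the whole prompt is prefilled in one call to
    prefill_attn; otherwise we decode one new token per sequence."""
    return "prefill" if past_len == 0 else "decode"

assert choose_attention_path(0) == "prefill"   # first call: full prompt, empty KV cache
assert choose_attention_path(17) == "decode"   # 17 tokens already cached
```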
