
Commit 6186aaf

use varlen if flash attn not available
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent fb71c2e

2 files changed (+43 -18 lines)

.github/workflows/test_ipex.yml (+1 -1)

@@ -19,7 +19,7 @@ jobs:
       fail-fast: false
       matrix:
         transformers-version: ["4.46.0", "4.46.3"]
-        torch-version: ["2.5.*"]
+        torch-version: ["2.4.0", "2.5.*"]
 
     runs-on: ubuntu-22.04
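With torch 2.4.0 added alongside 2.5.*, the matrix now covers four combinations (two transformers releases x two torch versions) instead of two, presumably so the new varlen_attention fallback below also gets exercised on a torch build (2.4.0) that lacks the flash varlen kernel.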
optimum/exporters/ipex/modeling_utils.py (+42 -17)
@@ -24,23 +24,23 @@
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
 
-from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.intel.utils.import_utils import is_ipex_version, is_torch_version
 from optimum.intel.utils.modeling_utils import _setattr_from_module
 
 from .cache_utils import IPEXPagedCache
 
 
 logger = logging.getLogger(__name__)
 
-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
     logger.warning(
         f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
     )
 else:
-    from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding
+    from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention
     from intel_extension_for_pytorch.llm.modules import (
         Linear2SiluMul,
         LinearAdd,
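For context on the gate above: is_ipex_version and is_torch_version are assumed here to be simple packaging-style version comparators, so lowering _IPEX_MINIMUM_VERSION_FOR_PATCHING to "2.4.0" lets an IPEX 2.4.0 install take the else branch and pick up varlen_attention. A rough sketch of that assumed comparison, not the actual optimum.intel helpers:

# Sketch only: illustrates the comparison is_ipex_version / is_torch_version are
# assumed to perform (packaging.version based); not the real optimum.intel code.
from packaging import version

def version_is(installed: str, op: str, reference: str) -> bool:
    ops = {
        "<": lambda a, b: a < b,
        "<=": lambda a, b: a <= b,
        ">": lambda a, b: a > b,
        ">=": lambda a, b: a >= b,
        "==": lambda a, b: a == b,
    }
    return ops[op](version.parse(installed), version.parse(reference))

# With the new minimum, IPEX 2.4.0 no longer hits the warning path:
print(version_is("2.4.0", "<", "2.4.0"))  # False -> IPEX kernels are imported
# Under the old minimum of 2.5.0 it would have:
print(version_is("2.4.0", "<", "2.5.0"))  # True -> warning, no patching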
@@ -627,24 +627,49 @@ def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         return attn_output
 
+    # Maybe removed after torch 2.6 released
+    def has_flash_attn(self, query):
+        if query.device.type == "cpu":
+            return is_torch_version(">", "2.4.99")
+        elif query.device.type == "xpu":
+            return is_torch_version(">", "2.5.99")
+
     def varlen_attn(self, query, key, value, past_key_value, input_lens):
         # prefill, remove padding
         attn_output = torch.empty_like(query)
         seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-        PagedAttention.flash_attn_varlen_func(
-            attn_output,
-            query,
-            key,
-            value,
-            seq_len_tensor,
-            seq_len_tensor,
-            input_lens.max(),
-            input_lens.max(),
-            1.0 / math.sqrt(self.head_dim),
-            True,
-            past_key_value.block_tables,
-            None,
-        )
+        if self.has_flash_attn(query):
+            PagedAttention.flash_attn_varlen_func(
+                attn_output,
+                query,
+                key,
+                value,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                1.0 / math.sqrt(self.head_dim),
+                True,
+                past_key_value.block_tables,
+                None,
+            )
+        else:
+            varlen_attention(
+                query.contiguous() if query.device.type == "xpu" else query,
+                key.contiguous() if key.device.type == "xpu" else key,
+                value.contiguous() if value.device.type == "xpu" else value,
+                attn_output,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                0.0,
+                1.0 / math.sqrt(self.head_dim),
+                False,
+                True,
+                False,
+                None,
+            )
 
         return attn_output
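Taken together, has_flash_attn routes the prefill attention to PagedAttention.flash_attn_varlen_func only when the installed torch is new enough for the device (newer than 2.4.99 on CPU, newer than 2.5.99 on XPU) and otherwise falls back to IPEX's varlen_attention kernel. A minimal sketch of just that dispatch decision; device_type and torch_version are illustrative stand-ins for what the real method reads from query.device.type and is_torch_version:

# Sketch of the dispatch decision only, not the attention computation itself.
from packaging import version

def uses_flash_varlen(device_type: str, torch_version: str) -> bool:
    v = version.parse(torch_version)
    if device_type == "cpu":
        # flash varlen kernel path needs torch > 2.4.99 (i.e. 2.5+) on CPU
        return v > version.parse("2.4.99")
    if device_type == "xpu":
        # and torch > 2.5.99 (i.e. 2.6+) on XPU
        return v > version.parse("2.5.99")
    return False

print(uses_flash_varlen("cpu", "2.4.0"))  # False -> varlen_attention fallback
print(uses_flash_varlen("cpu", "2.5.1"))  # True  -> PagedAttention.flash_attn_varlen_func
print(uses_flash_varlen("xpu", "2.5.1"))  # False -> varlen_attention fallback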
