 )
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
 
-from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.intel.utils.import_utils import is_ipex_version, is_torch_version
 from optimum.intel.utils.modeling_utils import _setattr_from_module
 
 from .cache_utils import IPEXPagedCache
 
 
 logger = logging.getLogger(__name__)
 
-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
     logger.warning(
         f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model."
     )
 else:
-    from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding
+    from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention
     from intel_extension_for_pytorch.llm.modules import (
         Linear2SiluMul,
         LinearAdd,
@@ -627,24 +627,49 @@ def postprocess_attention_output(self, attn_output):
         attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1])
         return attn_output
 
+    # May be removed after torch 2.6 is released
+    def has_flash_attn(self, query):
+        if query.device.type == "cpu":
+            return is_torch_version(">", "2.4.99")
+        elif query.device.type == "xpu":
+            return is_torch_version(">", "2.5.99")
+
     def varlen_attn(self, query, key, value, past_key_value, input_lens):
         # prefill, remove padding
         attn_output = torch.empty_like(query)
         seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-        PagedAttention.flash_attn_varlen_func(
-            attn_output,
-            query,
-            key,
-            value,
-            seq_len_tensor,
-            seq_len_tensor,
-            input_lens.max(),
-            input_lens.max(),
-            1.0 / math.sqrt(self.head_dim),
-            True,
-            past_key_value.block_tables,
-            None,
-        )
+        if self.has_flash_attn(query):
+            PagedAttention.flash_attn_varlen_func(
+                attn_output,
+                query,
+                key,
+                value,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                1.0 / math.sqrt(self.head_dim),
+                True,
+                past_key_value.block_tables,
+                None,
+            )
+        else:
+            varlen_attention(
+                query.contiguous() if query.device.type == "xpu" else query,
+                key.contiguous() if key.device.type == "xpu" else key,
+                value.contiguous() if value.device.type == "xpu" else value,
+                attn_output,
+                seq_len_tensor,
+                seq_len_tensor,
+                input_lens.max(),
+                input_lens.max(),
+                0.0,
+                1.0 / math.sqrt(self.head_dim),
+                False,
+                True,
+                False,
+                None,
+            )
 
         return attn_output
 
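For readers skimming the diff, the new prefill path reduces to a per-device torch version gate: PagedAttention.flash_attn_varlen_func is used when the installed torch is new enough for the device (effectively >= 2.5 on CPU and >= 2.6 on XPU, per the ">", "2.4.99" / "2.5.99" checks in has_flash_attn), otherwise the kernel falls back to IPEX's varlen_attention. The sketch below restates only that gate in standalone form; the helper name supports_flash_attn and the use of packaging.version are illustrative and not part of the patch.

from packaging import version

# Assumed minimum torch versions per device type, mirroring the ">" "2.4.99"
# and ">" "2.5.99" checks in has_flash_attn (i.e. effectively >= 2.5 / >= 2.6).
_ASSUMED_MIN_TORCH = {"cpu": "2.5.0", "xpu": "2.6.0"}

def supports_flash_attn(device_type: str, torch_version: str) -> bool:
    """Illustrative gate: True -> flash_attn_varlen_func path, False -> varlen_attention fallback."""
    minimum = _ASSUMED_MIN_TORCH.get(device_type)
    if minimum is None:
        # Device types other than cpu/xpu fall through to the varlen_attention fallback.
        return False
    return version.parse(torch_version) >= version.parse(minimum)

print(supports_flash_attn("cpu", "2.5.1"))  # True  -> flash attention path
print(supports_flash_attn("xpu", "2.5.1"))  # False -> varlen_attention fallback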