|
54 | 54 | _patch_model,
|
55 | 55 | )
|
56 | 56 | from ..utils.constant import _TASK_ALIASES
|
57 |
| -from ..utils.import_utils import is_ipex_version, is_transformers_version |
| 57 | +from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version |
58 | 58 | from ..utils.modeling_utils import recursive_to_device
|
59 | 59 |
|
60 | 60 |
|
|
64 | 64 | _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2")
|
65 | 65 | _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
|
66 | 66 | _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
|
67 |
| -# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6 |
68 |
| -_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2") |
| 67 | +# Paged attention models cannot use torch.compile for now. |
| 68 | +if is_torch_version("<", "2.6"): |
| 69 | + _COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2") |
| 70 | +else: |
| 71 | + _COMPILE_NOT_READY_MODEL_TYPES = ("llama", "falcon", "gpt2", "qwen2") |
69 | 72 |
|
70 | 73 |
|
71 | 74 | def _is_patched_with_ipex(model, task, use_cache: bool = True):
|
|
0 commit comments