upgrade transformers to 4.49 for patching models #1196

Merged · 3 commits · Mar 17, 2025
2 changes: 1 addition & 1 deletion .github/workflows/test_ipex.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
transformers-version: ["4.47.*"]
transformers-version: ["4.49.0"]
torch-version: ["2.6.0"]

runs-on: ubuntu-22.04
6 changes: 4 additions & 2 deletions optimum/exporters/ipex/cache_utils.py
@@ -1,6 +1,7 @@
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch as ipex
import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig
@@ -38,13 +39,14 @@ def __init__(
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: int,
device,
device=None,
dtype=None,
layer_device_map=None,
**kwargs,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu")
Member comment: this seems more like a global variable. Also, if I understand correctly, it is fine when an XPU device is available and ipex has XPU support; but if an XPU device is available and the installed ipex does not support XPU, it may make sense to warn the user and then proceed with the CPU device (a sketch of that fallback follows this file's diff).

device = device or default_device
self.device = device
self._supports_flash_decoding = (
is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
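A minimal sketch of the fallback the reviewer is suggesting, assuming `torch.xpu.is_available()` and `ipex._C._has_xpu()` behave as they are used elsewhere in this PR; the helper name `_default_cache_device` is illustrative and not part of the change:

import logging

import intel_extension_for_pytorch as ipex
import torch

logger = logging.getLogger(__name__)


def _default_cache_device() -> torch.device:
    # Hypothetical helper: prefer XPU only when both torch and the installed
    # ipex build expose it; otherwise warn and fall back to CPU.
    torch_sees_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    ipex_has_xpu = ipex._C._has_xpu()
    if torch_sees_xpu and not ipex_has_xpu:
        logger.warning(
            "An XPU device is visible to torch, but the installed ipex build has no XPU "
            "support; falling back to CPU for the paged attention cache."
        )
    return torch.device("xpu") if (torch_sees_xpu and ipex_has_xpu) else torch.device("cpu")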
4 changes: 2 additions & 2 deletions optimum/exporters/ipex/model_patcher.py
@@ -46,8 +46,8 @@


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.47.0"
_TRANSFORMERS_MAX_VERSION = "4.47.99"
_TRANSFORMERS_MIN_VERSION = "4.49.0"
_TRANSFORMERS_MAX_VERSION = "4.49.0"

_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

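The two constants above pin the supported transformers range to exactly 4.49.0. As a rough illustration (the actual guard in `model_patcher.py` is not shown in this diff), such bounds can be checked with `packaging.version`:

import transformers
from packaging import version

_TRANSFORMERS_MIN_VERSION = "4.49.0"
_TRANSFORMERS_MAX_VERSION = "4.49.0"

# Illustrative range check: with min == max, only the 4.49.0 release satisfies it.
current = version.parse(transformers.__version__)
if not version.parse(_TRANSFORMERS_MIN_VERSION) <= current <= version.parse(_TRANSFORMERS_MAX_VERSION):
    raise ImportError(
        f"transformers {transformers.__version__} is not supported for IPEX model patching; "
        f"expected a version between {_TRANSFORMERS_MIN_VERSION} and {_TRANSFORMERS_MAX_VERSION}."
    )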
25 changes: 14 additions & 11 deletions optimum/exporters/ipex/modeling_utils.py
@@ -346,8 +346,8 @@ def _falcon_model_forward(

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
# attention_probs has shape batch_size x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = inputs_embeds

@@ -707,7 +709,9 @@ def __init__(self, module, device, config) -> None:
_setattr_from_module(self, module)
self.config = config
self.module_device = device
self.num_groups = self.num_heads // self.num_key_value_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_attention_heads = config.num_attention_heads
self.num_groups = self.num_attention_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
).repeat_interleave(self.num_groups)
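The `kv_head_mapping` built above assigns each query head to the key/value head it shares under grouped-query attention. A standalone example with made-up head counts (8 attention heads, 2 KV heads) shows what `repeat_interleave` produces:

import torch

num_attention_heads = 8  # e.g. config.num_attention_heads
num_key_value_heads = 2  # e.g. config.num_key_value_heads
num_groups = num_attention_heads // num_key_value_heads  # 4 query heads share each KV head

kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)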
@@ -894,11 +896,11 @@ def __init__(self, module, device, config) -> None:
def qkv_gemm(self, hidden_states):
if hasattr(self, "concat_qkv"):
qkv_out = self.concat_qkv(hidden_states)
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
else:
query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
query = self.q_proj(hidden_states).view(-1, self.num_attention_heads, self.head_dim)
key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
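To illustrate the slicing in `qkv_gemm` above: when the query, key and value projections are fused into a single matmul, the output is split by column offsets (`q_slice`, `k_slice`) and reshaped per head. A self-contained sketch with made-up sizes; the offsets are computed inline here, whereas the module caches them as attributes:

import torch

num_attention_heads = 4
num_key_value_heads = 2
head_dim = 8

q_slice = num_attention_heads * head_dim            # 32: end of the query columns
k_slice = q_slice + num_key_value_heads * head_dim  # 48: end of the key columns

# Fused projection output for 6 tokens: [tokens, q_width + k_width + v_width]
qkv_out = torch.randn(6, q_slice + 2 * num_key_value_heads * head_dim)

query = qkv_out[:, :q_slice].view(-1, num_attention_heads, head_dim)       # [6, 4, 8]
key = qkv_out[:, q_slice:k_slice].view(-1, num_key_value_heads, head_dim)  # [6, 2, 8]
value = qkv_out[:, k_slice:].view(-1, num_key_value_heads, head_dim)       # [6, 2, 8]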

@@ -916,20 +918,21 @@ def __init__(self, module, device, config):
def qkv_gemm(self, hidden_states):
qkv_out = self.query_key_value(hidden_states)
if self.new_decoder_architecture:
qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
qkv_out = qkv_out.view(
qkv_out.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim
)
query = qkv_out[:, :, :-2, :].flatten(1, 2)
key = qkv_out[:, :, [-2], :].flatten(1, 2)
value = qkv_out[:, :, [-1], :].flatten(1, 2)
else:
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
return query, key, value


class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, device, config) -> None:
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, device, config)
_setattr_from_module(self, module)
if getattr(config, "quantization_config", None) is None:
@@ -952,9 +955,9 @@ def qkv_gemm(self, hidden_states):
query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
query = query.view(-1, self.num_heads, self.head_dim)
key = key.view(-1, self.num_heads, self.head_dim)
value = value.view(-1, self.num_heads, self.head_dim)
query = query.view(-1, self.num_attention_heads, self.head_dim)
key = key.view(-1, self.num_attention_heads, self.head_dim)
value = value.view(-1, self.num_attention_heads, self.head_dim)
return query, key, value

def rope(self, query, key, *args, **kwargs):
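For Falcon's new decoder architecture, the fused QKV output packs `num_attention_heads // num_kv_heads` query heads plus one key head and one value head per KV group, which the reshape and flatten in the Falcon `qkv_gemm` above recover. A standalone sketch of that layout with small made-up dimensions:

import torch

num_attention_heads = 8
num_kv_heads = 2
head_dim = 4
tokens = 3
heads_per_group = num_attention_heads // num_kv_heads  # 4 query heads per KV group

# Fused projection: each KV group packs (heads_per_group + 2) heads of width head_dim.
qkv_out = torch.randn(tokens, num_kv_heads * (heads_per_group + 2) * head_dim)
qkv_out = qkv_out.view(tokens, -1, heads_per_group + 2, head_dim)  # [3, 2, 6, 4]

query = qkv_out[:, :, :-2, :].flatten(1, 2)   # [3, 8, 4]: all query heads
key = qkv_out[:, :, [-2], :].flatten(1, 2)    # [3, 2, 4]: one key head per group
value = qkv_out[:, :, [-1], :].flatten(1, 2)  # [3, 2, 4]: one value head per group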
11 changes: 11 additions & 0 deletions optimum/intel/ipex/modeling_base.py
@@ -71,6 +71,17 @@
_COMPILE_NOT_READY_MODEL_TYPES = ("llama", "falcon", "gpt2", "qwen2")


try:
import intel_extension_for_pytorch as ipex

if hasattr(torch, "xpu") and torch.xpu.is_available() and not ipex._C._has_xpu():
logger.warning(
"Detect you have XPU device but the ipex do not support XPU, please install a xpu version ipex by checking https://pytorch-extension.intel.com/installation?platform=gpu"
)
except ImportError:
logger.warning("No intel_extension_for_pytorch found, please `pip install intel_extension_for_pytorch`")


def _is_patched_with_ipex(model, task, use_cache: bool = True):
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
return False
2 changes: 1 addition & 1 deletion setup.py
@@ -67,7 +67,7 @@
"nncf": ["nncf>=2.14.0"],
"openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
"neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
"ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.46,<4.48", "accelerate"],
"ipex": ["intel-extension-for-pytorch>=2.6", "transformers>4.48,<4.50", "accelerate"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
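The tightened `ipex` extra now requires `intel-extension-for-pytorch>=2.6` and `transformers>4.48,<4.50`. As a rough illustration (not part of the PR), an installed environment can be checked against these specifiers with `importlib.metadata` and `packaging`:

from importlib.metadata import PackageNotFoundError, version as installed_version

from packaging.specifiers import SpecifierSet

# Specifiers mirrored from the "ipex" extra above.
requirements = {
    "intel-extension-for-pytorch": SpecifierSet(">=2.6"),
    "transformers": SpecifierSet(">4.48,<4.50"),
}

for package, specifier in requirements.items():
    try:
        current = installed_version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed (required: {specifier})")
        continue
    status = "ok" if current in specifier else f"unsupported (required: {specifier})"
    print(f"{package}=={current}: {status}")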