
Commit 160f65c

upgrade transformers to 4.49 for patching models

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 7f70f2b, commit 160f65c

4 files changed: +21 -16 lines changed

.github/workflows/test_ipex.yml  (+1 -1)

@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        transformers-version: ["4.47.*"]
+        transformers-version: ["4.49.0"]
         torch-version: ["2.6.0"]

     runs-on: ubuntu-22.04

optimum/exporters/ipex/cache_utils.py  (+4 -2)

@@ -1,6 +1,7 @@
 import os
 from typing import List, Optional, Tuple

+import intel_extension_for_pytorch as ipex
 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig
@@ -38,13 +39,14 @@ def __init__(
         config: PretrainedConfig,
         max_batch_size: int,
         max_cache_len: int,
-        device,
+        device=None,
         dtype=None,
-        layer_device_map=None,
         **kwargs,
     ) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
+        default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu")
+        device = device or default_device
         self.device = device
         self._supports_flash_decoding = (
             is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")

optimum/exporters/ipex/model_patcher.py  (+2 -2)

@@ -46,8 +46,8 @@


 # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
-_TRANSFORMERS_MIN_VERSION = "4.47.0"
-_TRANSFORMERS_MAX_VERSION = "4.47.99"
+_TRANSFORMERS_MIN_VERSION = "4.49.0"
+_TRANSFORMERS_MAX_VERSION = "4.49.0"

 _IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)
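
Note: with identical min and max bounds, only transformers 4.49.0 is accepted for patching. A hedged sketch of such an inclusive range check using packaging (not necessarily the helper the repository itself uses):

from packaging import version

_TRANSFORMERS_MIN_VERSION = "4.49.0"
_TRANSFORMERS_MAX_VERSION = "4.49.0"

def transformers_version_supported(installed: str) -> bool:
    # Inclusive bounds; when min == max, exactly one release passes.
    v = version.parse(installed)
    return version.parse(_TRANSFORMERS_MIN_VERSION) <= v <= version.parse(_TRANSFORMERS_MAX_VERSION)

assert transformers_version_supported("4.49.0")
assert not transformers_version_supported("4.47.1")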

optimum/exporters/ipex/modeling_utils.py  (+14 -11)

@@ -346,8 +346,8 @@ def _falcon_model_forward(

     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
-    # attention_probs has shape batch_size x num_heads x N x N
-    # head_mask has shape n_layer x batch x num_heads x N x N
+    # attention_probs has shape batch_size x num_attention_heads x N x N
+    # head_mask has shape n_layer x batch x num_attention_heads x N x N
     head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
     hidden_states = inputs_embeds

@@ -707,7 +707,9 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_attention_heads = config.num_attention_heads
+        self.num_groups = self.num_attention_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
         ).repeat_interleave(self.num_groups)
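
Note: taking the head counts straight from the config makes the grouped-query bookkeeping explicit: num_groups query heads share each KV head, and kv_head_mapping expands the KV head index once per query head. A small illustration with made-up head counts (not tied to any particular model):

import torch

num_attention_heads = 32   # assumed number of query heads
num_key_value_heads = 8    # assumed number of KV heads (grouped-query attention)
num_groups = num_attention_heads // num_key_value_heads  # 4 query heads per KV head

kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7], dtype=torch.int32)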

@@ -894,11 +896,11 @@ def __init__(self, module, device, config) -> None:
     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
             qkv_out = self.concat_qkv(hidden_states)
-            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
             key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         else:
-            query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
+            query = self.q_proj(hidden_states).view(-1, self.num_attention_heads, self.head_dim)
             key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
             value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
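
Note: the fused-projection path slices one wide matmul output into Q, K and V and reshapes each piece to (tokens, heads, head_dim); only the query uses num_attention_heads, while K and V use the smaller KV head count. A standalone sketch with assumed sizes (the random tensor stands in for concat_qkv(hidden_states)):

import torch

# Hypothetical sizes: 5 tokens, 32 query heads, 8 KV heads, head_dim 64.
tokens, n_q, n_kv, head_dim = 5, 32, 8, 64
q_slice = n_q * head_dim
k_slice = q_slice + n_kv * head_dim

qkv_out = torch.randn(tokens, q_slice + 2 * n_kv * head_dim)  # stand-in for the fused QKV output
query = qkv_out[:, :q_slice].view(-1, n_q, head_dim)
key = qkv_out[:, q_slice:k_slice].view(-1, n_kv, head_dim)
value = qkv_out[:, k_slice:].view(-1, n_kv, head_dim)
print(query.shape, key.shape, value.shape)  # (5, 32, 64) (5, 8, 64) (5, 8, 64)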

@@ -916,20 +918,21 @@ def __init__(self, module, device, config):
     def qkv_gemm(self, hidden_states):
         qkv_out = self.query_key_value(hidden_states)
         if self.new_decoder_architecture:
-            qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
+            qkv_out = qkv_out.view(
+                qkv_out.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim
+            )
             query = qkv_out[:, :, :-2, :].flatten(1, 2)
             key = qkv_out[:, :, [-2], :].flatten(1, 2)
             value = qkv_out[:, :, [-1], :].flatten(1, 2)
         else:
-            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
             key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         return query, key, value


 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
-        self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
         if getattr(config, "quantization_config", None) is None:
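
Note: for Falcon's new decoder architecture the fused output is viewed as (tokens, kv_heads, queries_per_group + 2, head_dim), with the last two slots holding K and V. A sketch under assumed head counts; the layout interpretation simply follows the patched view/flatten calls:

import torch

# Hypothetical Falcon-style fused QKV with the "new decoder architecture" layout.
tokens, n_q, n_kv, head_dim = 3, 16, 2, 64
qkv_out = torch.randn(tokens, (n_q + 2 * n_kv) * head_dim)  # stand-in for query_key_value(hidden_states)

qkv_out = qkv_out.view(tokens, -1, n_q // n_kv + 2, head_dim)  # (tokens, n_kv, queries_per_group + 2, head_dim)
query = qkv_out[:, :, :-2, :].flatten(1, 2)   # (tokens, n_q, head_dim)
key = qkv_out[:, :, [-2], :].flatten(1, 2)    # (tokens, n_kv, head_dim)
value = qkv_out[:, :, [-1], :].flatten(1, 2)  # (tokens, n_kv, head_dim)
print(query.shape, key.shape, value.shape)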

@@ -952,9 +955,9 @@ def qkv_gemm(self, hidden_states):
             query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
-        query = query.view(-1, self.num_heads, self.head_dim)
-        key = key.view(-1, self.num_heads, self.head_dim)
-        value = value.view(-1, self.num_heads, self.head_dim)
+        query = query.view(-1, self.num_attention_heads, self.head_dim)
+        key = key.view(-1, self.num_attention_heads, self.head_dim)
+        value = value.view(-1, self.num_attention_heads, self.head_dim)
         return query, key, value

     def rope(self, query, key, *args, **kwargs):
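
Note: GPT-2 has no separate KV head count, so the fused c_attn output is split into three equal chunks and each is reshaped with num_attention_heads. A small self-contained sketch with assumed sizes (the random tensor stands in for the c_attn output):

import torch

# Hypothetical GPT-2-style fused attention projection: all heads are full attention heads (no GQA).
tokens, n_heads, head_dim = 4, 12, 64
split_size = n_heads * head_dim
c_attn_out = torch.randn(tokens, 3 * split_size)  # stand-in for self.c_attn(hidden_states)

query, key, value = c_attn_out.split(split_size, dim=-1)
query = query.view(-1, n_heads, head_dim)
key = key.view(-1, n_heads, head_dim)
value = value.view(-1, n_heads, head_dim)
print(query.shape)  # torch.Size([4, 12, 64])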
