Commit a27f0d1

Merge branch 'upgrade' into patch
2 parents: d7af7ba + 1b0dc0d

5 files changed (+22, -17 lines)

.github/workflows/test_ipex.yml (+1, -1)

@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        transformers-version: ["4.47.*"]
+        transformers-version: ["4.49.0"]
         torch-version: ["2.6.0"]
 
     runs-on: ubuntu-22.04

optimum/exporters/ipex/cache_utils.py (+4, -2)

@@ -1,6 +1,7 @@
 import os
 from typing import List, Optional, Tuple
 
+import intel_extension_for_pytorch as ipex
 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
 from transformers import Cache, PretrainedConfig
@@ -38,13 +39,14 @@ def __init__(
         config: PretrainedConfig,
         max_batch_size: int,
         max_cache_len: int,
-        device,
+        device=None,
         dtype=None,
-        layer_device_map=None,
         **kwargs,
     ) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
+        default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu")
+        device = device or default_device
         self.device = device
         self._supports_flash_decoding = (
             is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
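With this change the cache no longer requires an explicit device: when none is passed it picks XPU if IPEX reports one, otherwise CPU. A minimal standalone sketch of that fallback, assuming an IPEX install (ipex._C._has_xpu() is the private helper the patch itself relies on; resolve_cache_device is an illustrative name, not the exporter's API):

import torch
import intel_extension_for_pytorch as ipex

def resolve_cache_device(device=None):
    # Prefer an Intel XPU when IPEX reports one, otherwise fall back to CPU,
    # mirroring the default added to the paged-cache __init__ above.
    default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu")
    return torch.device(device) if device is not None else default_device

print(resolve_cache_device())        # xpu on XPU machines, cpu elsewhere
print(resolve_cache_device("cpu"))   # an explicit device still takes precedence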

optimum/exporters/ipex/model_patcher.py (+2, -2)

@@ -46,8 +46,8 @@
 
 
 # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
-_TRANSFORMERS_MIN_VERSION = "4.47.0"
-_TRANSFORMERS_MAX_VERSION = "4.47.99"
+_TRANSFORMERS_MIN_VERSION = "4.49.0"
+_TRANSFORMERS_MAX_VERSION = "4.49.0"
 
 _IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)
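Pinning both bounds to 4.49.0 means the patcher now targets exactly one transformers release. A rough sketch of how such a min/max range can be enforced at import time, using packaging (illustrative only, not the exporter's actual validation code):

from packaging import version
import transformers

_TRANSFORMERS_MIN_VERSION = "4.49.0"
_TRANSFORMERS_MAX_VERSION = "4.49.0"

installed = version.parse(transformers.__version__)
if not version.parse(_TRANSFORMERS_MIN_VERSION) <= installed <= version.parse(_TRANSFORMERS_MAX_VERSION):
    raise ImportError(
        f"transformers {installed} is not supported; expected a version between "
        f"{_TRANSFORMERS_MIN_VERSION} and {_TRANSFORMERS_MAX_VERSION}."
    )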

optimum/exporters/ipex/modeling_utils.py (+14, -11)

@@ -346,8 +346,8 @@ def _falcon_model_forward(
 
     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
-    # attention_probs has shape batch_size x num_heads x N x N
-    # head_mask has shape n_layer x batch x num_heads x N x N
+    # attention_probs has shape batch_size x num_attention_heads x N x N
+    # head_mask has shape n_layer x batch x num_attention_heads x N x N
     head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
     hidden_states = inputs_embeds
 
@@ -707,7 +707,9 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_attention_heads = config.num_attention_heads
+        self.num_groups = self.num_attention_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
         ).repeat_interleave(self.num_groups)
@@ -892,11 +894,11 @@ def __init__(self, module, device, config) -> None:
     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
             qkv_out = self.concat_qkv(hidden_states)
-            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
             key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         else:
-            query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
+            query = self.q_proj(hidden_states).view(-1, self.num_attention_heads, self.head_dim)
             key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
             value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
 
@@ -914,20 +916,21 @@ def __init__(self, module, device, config):
     def qkv_gemm(self, hidden_states):
         qkv_out = self.query_key_value(hidden_states)
         if self.new_decoder_architecture:
-            qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
+            qkv_out = qkv_out.view(
+                qkv_out.shape[0], -1, self.num_attention_heads // self.num_kv_heads + 2, self.head_dim
+            )
             query = qkv_out[:, :, :-2, :].flatten(1, 2)
             key = qkv_out[:, :, [-2], :].flatten(1, 2)
             value = qkv_out[:, :, [-1], :].flatten(1, 2)
         else:
-            query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
+            query = qkv_out[:, : self.q_slice].view(-1, self.num_attention_heads, self.head_dim)
             key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
             value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
         return query, key, value
 
 
 class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
-        self.num_key_value_heads = config.num_key_value_heads
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
         if not config.compile and getattr(config, "quantization_config", None) is None:
@@ -950,9 +953,9 @@ def qkv_gemm(self, hidden_states):
             query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
-        query = query.view(-1, self.num_heads, self.head_dim)
-        key = key.view(-1, self.num_heads, self.head_dim)
-        value = value.view(-1, self.num_heads, self.head_dim)
+        query = query.view(-1, self.num_attention_heads, self.head_dim)
+        key = key.view(-1, self.num_attention_heads, self.head_dim)
+        value = value.view(-1, self.num_attention_heads, self.head_dim)
         return query, key, value
 
     def rope(self, query, key, *args, **kwargs):
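These attention changes read both head counts from the model config and reshape queries with num_attention_heads while keys and values keep num_key_value_heads, i.e. the usual grouped-query-attention layout. A small sketch of the head bookkeeping set up in __init__ above (the head counts are made-up example values, not taken from any particular model):

import torch

num_attention_heads = 32  # example value; the patch reads config.num_attention_heads
num_key_value_heads = 8   # example value; the patch reads config.num_key_value_heads
num_groups = num_attention_heads // num_key_value_heads  # 4 query heads share each KV head

# Same construction as kv_head_mapping in the diff: query head i is served by
# KV head i // num_groups, i.e. [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7].
kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping.tolist())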

setup.py (+1, -1)

@@ -67,7 +67,7 @@
     "nncf": ["nncf>=2.14.0"],
     "openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
     "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
-    "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.46,<4.48", "accelerate"],
+    "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.48,<4.50", "accelerate"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,
