
Commit 41f0a46

sywangyi, jiqing-feng, kaixuanliu, and IlyasMoutawwakil authored
unify xpu and cpu backend and use paged attention (huggingface#1009)
* add paged attention implementation, remove jit logic
* add support for transformers 4.45
* fix config (huggingface#935)
* move patch model to init
* refine class IPEXPagedCache's update method (huggingface#945)
* refine class IPEXPagedCache's update method
* replace tensor on xpu with a List to avoid memory copy
* split IPEXPagedCache's update function into `update_for_prefill` and `update_for_decode`
* fix bug when doing beam search (huggingface#954)
* enable qkv concat layer (huggingface#958)
* enable qkv
* split key value into 2 lists
* add xpu cache optimization
* xpu mlp optimization
* optimize cache ops in xpu, improve for beam search
* enable gpt2; falcon has core dump error in PagedAttention.single_query_cached_kv_attention (huggingface#979)
* enable gpt2, falcon has core dump error in PagedAttention.single_query_cached_kv_attention
* enable new_decoder_arch falcon
* only keep 1 config
* rm autocast
* fix unit test case, CPU part is OK; enable Falcon7b for XPU (huggingface#992)
* fix bug when running IPEXCausalModel forward directly; fix bug when using `save_pretrain`
* add LinearGelu op support for XPU
* fix unit test error
* adjust unit test case
* fix bug
* skip assisted decoding unit test for models using paged attention (huggingface#998)
* skip assisted decoding unit test for models using paged attention
* XPU CI tests almost all pass
* fix ci config (huggingface#1010)
* Fix tests versions (huggingface#1011)
* fix ci config
* fix test versions
* fix ipex version
* fix torch test version (huggingface#1012)
* use python3.9 test (huggingface#1013)
* change ipex transformers version limit in setup (huggingface#1015)
* change ipex transformers version limit in setup
* fix inc tests
* add XPU LinearAddAdd op (huggingface#1017)
* fix bert and vit patch (huggingface#1022)
* fix bert and vit patch
* fix vit and bert save
* Paged attn (huggingface#1024)
* fix reorder cache for non-patched models
* disable torch < 2.3 tests, we won't use torch < 2.4
* fix test beam search
* fix cache selection
* upgrade to transformers 4.46
* change ipex test yaml transformers version to 4.46
* set device the same as the origin model (huggingface#1031)
* set device the same as the origin model
* fix device
* Simplify IPEXModel (huggingface#1032)
* simplify forward and save_pretrained since there is no jit support
* fix format
* rm warmup because there is no jit mode anymore
* simplify forward for causal lm model
* fix paged pkv forward
* disable use_cache when just running forward
* nice code (huggingface#1035)
* Paged attn (huggingface#1036)
* nice code
* device type adjustment
* Enable torch.compile for non-generation tasks on CPU (huggingface#1037)
* enable compile for non-generation tasks
* add no_grad in forward
* warmup compiled model
* disable compile for not-ready models
* set system level optimize for torch.compile
* fix typo
* add comments
* set torch minimum version for compiling
* Fix ipex upload and update readme (huggingface#1045)
* fix readme and push-to-hub support
* rm export in tests
* test with torch 2.5.*
* Fix tests (huggingface#1047)
* fix tests
* fix typo
* add patched tests
* change forward to generate
* fix tests
* fix test model name
* Patch gpt2 block forward for passing input_lens (huggingface#1050)
* fix forward without pkv
* patch gpt2 block forward
* fix typo
* revert causal lm tests

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: kaixuanliu <kaixuan.liu@intel.com>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
1 parent c94b3f5 commit 41f0a46

File tree

13 files changed: +1035 -860 lines

.github/workflows/test_inc.yml

+1 -1

@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        torch-version: ["2.4.*", "2.5.0"]
+        torch-version: ["2.4.0", "2.5.*"]
 
     runs-on: ubuntu-22.04

.github/workflows/test_ipex.yml

+2 -6

@@ -18,8 +18,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        torch-version: ["2.2.0", "2.3.*"]
-        transformers-version: ["4.39.0", "4.44.*"]
+        transformers-version: ["4.46.0", "4.46.3"]
+        torch-version: ["2.4.0", "2.5.*"]
 
     runs-on: ubuntu-22.04
 
@@ -38,10 +38,6 @@ jobs:
           pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
           pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }}
 
-      - if: ${{ matrix.torch-version == '2.2.0' }}
-        name: Downgrade Numpy
-        run: pip install numpy==1.*
-
       - name: Assert versions
         run: |
           python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
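For reference, the "Assert versions" step above relies on a simple prefix match between the matrix entry and the installed package version. A small illustration of that matching logic (not part of the CI; it assumes torch is installed locally):

```python
# Illustration of the version check used in the "Assert versions" step:
# a matrix entry such as "2.5.*" is matched by stripping the ".*" wildcard
# and doing a prefix comparison against the installed version string.
import torch

for matrix_entry in ["2.4.0", "2.5.*"]:
    prefix = matrix_entry.replace(".*", "")
    print(matrix_entry, torch.__version__.startswith(prefix))
```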

docs/source/ipex/inference.mdx

+3 -3

@@ -14,8 +14,8 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m
 
 ## Loading
 
-You can load your model and apply IPEX optimizations (including weight prepacking and graph mode). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
-For now, support is only enabled for CPUs and the original model will be exported via TorchScript. In the future `torch.compile` will be used and model exported via TorchScript will get deprecated.
+You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
+For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22.
 
 ```diff
 import torch
@@ -25,7 +25,7 @@ For now, support is only enabled for CPUs and the original model will be exporte
 
 model_id = "gpt2"
 - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True)
++ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 results = pipe("He's a dreadful magician and")
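Put together, the updated docs boil down to the following usage, assembled here as a runnable sketch from the diff above (the `AutoTokenizer` and `pipeline` imports are assumed from `transformers`, as on the full doc page; `export=True` is no longer needed):

```python
# Sketch of the updated loading flow from the docs diff above:
# IPEXModelForCausalLM applies IPEX optimizations on load, without the
# previously required TorchScript export path.
import torch
from transformers import AutoTokenizer, pipeline
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
results = pipe("He's a dreadful magician and")
print(results)
```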

optimum/exporters/ipex/cache_utils.py

+238 (new file)

from typing import List, Optional, Tuple

import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig


class IPEXPagedCache(Cache):
    """
    A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout.
    ipex-xpu:
    ipex-cpu:

    Example:

    ```python
    >>> from transformers import AutoTokenizer
    >>> from optimum.intel import IPEXModelForCausalLM
    >>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache

    >>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True)
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = IPEXPagedCache()
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
    ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int,
        max_cache_len: int,
        device,
        dtype=None,
        layer_device_map=None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        # Used in `generate` to keep tally of how many tokens the cache has seen
        self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
        self.block_size = 16
        self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
        self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
            batch_size, -1
        )
        self.free_blocks = torch.arange(self.num_blocks, device=device)
        self.max_cache_len = max_cache_len
        self.num_kv_heads = config.num_key_value_heads
        self.num_hidden_layers = config.num_hidden_layers
        if hasattr(config, "head_dim"):
            head_size = config.head_dim
        else:
            head_size = config.hidden_size // config.num_attention_heads
        self.head_size = head_size
        self.max_seq_len = 0

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        if device.type == "cpu":
            key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
            value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
        elif device.type == "xpu":
            key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
            value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
        for i in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
            new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update_for_prefill(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        batch_size: int,
        input_lens: torch.Tensor,
    ):
        if layer_idx == 0:
            all_block_indices = []
            all_slot_offsets = []
            num_blocks = (input_lens + self.block_size - 1) // self.block_size
            for i in range(batch_size):
                for b_idx in range(num_blocks[i]):
                    if self.block_tables[i][b_idx] == -1:
                        # need a free block
                        self.block_tables[i][b_idx] = self.free_blocks[0]
                        self.free_blocks = self.free_blocks[1:]

                slots_range = torch.arange(input_lens[i], device=key_states.device)
                block_indices = slots_range // self.block_size
                slot_offsets = slots_range % self.block_size
                all_block_indices.append(self.block_tables[i][block_indices])
                all_slot_offsets.append(slot_offsets)

            all_block_indices = torch.cat(all_block_indices)
            all_slot_offsets = torch.cat(all_slot_offsets)
            self.slots = all_block_indices * self.block_size + all_slot_offsets

        # Update the cache
        PagedAttention.reshape_and_cache(
            key_states,
            value_states,
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            self.slots,
        )

        # Update the number of seen tokens
        if layer_idx == self.num_hidden_layers - 1:
            self._seen_tokens = self._seen_tokens + input_lens
            self.max_seq_len, _ = self._seen_tokens.max(dim=0)

    def update_for_decode(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        batch_size: int,
    ):
        if layer_idx == 0:
            start_block_idx = self._seen_tokens // self.block_size
            num_blocks = (self._seen_tokens + self.block_size) // self.block_size
            slot_offset_in_block = (self._seen_tokens) % self.block_size
            self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
            for i in range(batch_size):
                for b_idx in range(start_block_idx[i], num_blocks[i]):
                    if self.block_tables[i][b_idx] == -1:
                        # need a free block
                        self.block_tables[i][b_idx] = self.free_blocks[0]
                        self.free_blocks = self.free_blocks[1:]

                self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
        # Update the cache
        PagedAttention.reshape_and_cache(
            key_states,
            value_states,
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            self.slots,
        )

        # Update the number of seen tokens
        if layer_idx == self.num_hidden_layers - 1:
            self._seen_tokens = self._seen_tokens + 1
            self.max_seq_len = self.max_seq_len + 1

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        attention_mask: torch.Tensor,
        input_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
        Return:
            A tuple containing the updated key and value states.
        """

        batch_size = input_lens.shape[-1]
        if self.get_seq_length() == 0:
            # prefill
            self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
        else:
            # decode
            self.update_for_decode(key_states, value_states, layer_idx, batch_size)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        return self.max_seq_len

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
        self.block_tables.fill_(-1)
        self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
        self.max_seq_len = 0

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.block_tables.device
        origin_table = self.block_tables.clone()
        updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
        mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
        num_blocks = mask.cumsum(-1)[:, -1]
        updated_table = []
        for i in range(beam_idx.shape[0]):
            self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
            updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
        updated_table = torch.cat(tuple(updated_table), dim=0)
        for layer_idx in range(self.num_hidden_layers):
            self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
            self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
        free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
        self.free_blocks = torch.cat((self.free_blocks, free_table))

    def crop(self, maximum_length: int):
        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""

        max_seq_len = self.get_seq_length()
        if maximum_length < 0:
            maximum_length = max_seq_len - abs(maximum_length)

        if max_seq_len <= maximum_length:
            return
        origin_table = self.block_tables.clone()
        for bs in range(self._seen_tokens.shape[0]):
            new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len
            num_blocks = (new_tokens + self.block_size - 1) // self.block_size
            self.block_tables[bs, num_blocks:] = -1
            self._seen_tokens[bs] = new_tokens
        self.max_seq_len, _ = self._seen_tokens.max(dim=0)
        free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
        self.free_blocks = torch.cat((self.free_blocks, free_table))
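The cache above is addressed through a per-sequence block table: logical block `pos // block_size` of sequence `i` is lazily mapped to a free physical block, and a token at position `pos` lands in physical slot `block_tables[i][pos // block_size] * block_size + pos % block_size`. The following standalone sketch (not part of the commit; names and sizes are illustrative) reproduces just that slot arithmetic from `update_for_prefill`:

```python
# Standalone sketch of the paged-cache slot arithmetic used by IPEXPagedCache:
# tokens are mapped to physical slots via a per-sequence block table.
import torch

block_size = 16
num_blocks = 8
block_tables = -1 * torch.ones(2, num_blocks // 2, dtype=torch.int32)  # 2 sequences
free_blocks = list(range(num_blocks))

def slots_for_prefill(seq_idx: int, seq_len: int) -> torch.Tensor:
    positions = torch.arange(seq_len)
    block_indices = positions // block_size
    slot_offsets = positions % block_size
    # lazily assign free physical blocks to the logical blocks this sequence needs
    for b_idx in range(int(block_indices.max()) + 1):
        if block_tables[seq_idx][b_idx] == -1:
            block_tables[seq_idx][b_idx] = free_blocks.pop(0)
    return block_tables[seq_idx][block_indices] * block_size + slot_offsets

print(slots_for_prefill(0, 20))  # first 16 tokens fill one block, the remaining 4 spill into the next
```

Decode follows the same mapping for a single new token per sequence, which is why `update_for_decode` only needs the current block and the offset of the next free slot.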

optimum/exporters/ipex/model_patcher.py

+22 -17

@@ -13,11 +13,10 @@
 # limitations under the License.
 
 from transformers.models.bert.modeling_bert import BertIntermediate
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
-    LlamaForCausalLM,
     LlamaModel,
     LlamaRMSNorm,
 )
@@ -28,7 +27,9 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _falcon_model_forward,
     _gpt2_block_forward,
+    _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
     _IPEXGPT2Attention,
@@ -39,8 +40,8 @@
 
 
 # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
-_TRANSFORMERS_MIN_VERSION = "4.39.0"
-_TRANSFORMERS_MAX_VERSION = "4.44.99"
+_TRANSFORMERS_MIN_VERSION = "4.46.0"
+_TRANSFORMERS_MAX_VERSION = "4.46.99"
 
 _IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)
 
@@ -75,7 +76,7 @@ def patch_op(m, target_m, new_op_name, new_op):
 def _patch_llama_model(model):
     """
     Patch llama model:
-        1. Use IPEX Rope and IAKV cache
+        1. Use IPEX rope and paged cache
         2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
@@ -87,11 +88,14 @@ def _patch_llama_model(model):
 def _patch_falcon_model(model):
     """
     Patch falcon model:
-        1. Disable SDPA so the attention mask will be compatible to ipex attention.
-        2. Use IPEX Rope and IAKV cache
-        3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
+        1. Use IPEX rope and paged cache
+        2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
     """
-    model.transformer._use_sdpa = False
+    num_key_value_heads = (
+        model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
+    )
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
     convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
     return model
@@ -100,12 +104,13 @@ def _patch_falcon_model(model):
 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
-        1. Disable SDPA so the attention mask will be compatible to ipex attention.
-        2. Use IAKV cache
+        1. Use IPEX paged attention
     """
-    model.transformer._attn_implementation = "eager"
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
+    num_key_value_heads = model.config.num_attention_heads
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
+    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
     return model
 
 
@@ -136,11 +141,11 @@ def _patch_model(model):
         raise ImportError(
             f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
         )
-    if isinstance(model, LlamaForCausalLM):
+    if model.config.model_type == "llama":
         model = _patch_llama_model(model)
-    elif isinstance(model, FalconForCausalLM):
+    elif model.config.model_type == "falcon":
         model = _patch_falcon_model(model)
-    elif isinstance(model, GPT2LMHeadModel):
+    elif model.config.model_type == "gpt2":
         model = _patch_gpt2_model(model)
     elif model.config.model_type == "bert":
         model = _patch_bert_model(model)
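Note that the dispatch in `_patch_model` now keys on `model.config.model_type` instead of `isinstance` checks against head-specific classes. A minimal sketch of that pattern (hypothetical helper, not code from this commit):

```python
# Hypothetical sketch of the dispatch style _patch_model switches to above:
# keying on model.config.model_type matches any head variant of an architecture
# (e.g. a *ForCausalLM or a bare *Model), whereas isinstance(model, LlamaForCausalLM)
# only matched one specific head class.
def dispatch_patcher(model, patchers: dict):
    model_type = getattr(model.config, "model_type", None)
    if model_type in patchers:
        return patchers[model_type](model)
    return model  # leave unsupported architectures untouched
```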
