
Commit e001be9

enable gpt2; falcon has a core dump error in PagedAttention.single_query_cached_kv_attention

1 parent 45130c9

File tree

3 files changed (+293 −49 lines)


optimum/exporters/ipex/cache_utils.py

+1 −2
@@ -158,7 +158,6 @@ def update(
         value_states: torch.Tensor,
         layer_idx: int,
         attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
         input_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -175,7 +174,7 @@ def update(
             A tuple containing the updated key and value states.
         """

-        batch_size = position_ids.shape[0]
+        batch_size = input_lens.shape[-1]
         if self.get_seq_length() == 0:
             # prefill
             self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
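
The substantive change here is where the paged cache derives its batch size: instead of position_ids.shape[0], it now uses the last dimension of input_lens, which lets position_ids be dropped from the update signature entirely. A minimal sketch of the equivalence, with made-up tensor values purely for illustration:

import torch

# Illustrative values only: input_lens holds one prompt length per sequence in the
# batch, position_ids has shape (batch_size, seq_len). Neither comes from a real run.
input_lens = torch.tensor([5, 3, 7])
position_ids = torch.arange(7).repeat(3, 1)

old_batch_size = position_ids.shape[0]   # 3, but forced callers to pass position_ids through
new_batch_size = input_lens.shape[-1]    # 3, taken from a tensor the cache already receives

assert old_batch_size == new_batch_size == 3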

optimum/exporters/ipex/model_patcher.py

+10 −6
@@ -13,8 +13,8 @@
 # limitations under the License.

 from transformers.models.bert.modeling_bert import BertIntermediate
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
+from transformers.models.falcon.modeling_falcon import FalconModel, FalconDecoderLayer
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,13 +27,14 @@

 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _gpt2_block_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
     _IPEXGPT2Attention,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _llama_model_forward,
+    _falcon_model_forward,
+    _gpt2_model_forward,
 )

@@ -90,7 +91,9 @@ def _patch_falcon_model(model):
     2. Use IPEX Rope and paged cache
     3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
     """
-    model.transformer._use_sdpa = False
+    num_key_value_heads = model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
     convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
     return model
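
Two things change in the Falcon patch: the config gains an explicit num_key_value_heads attribute (checkpoints with the new decoder architecture, or that are not multi-query, expose num_kv_heads, while plain multi-query checkpoints share a single KV head), and FalconModel.forward is rebound to _falcon_model_forward via convert_functions instead of merely toggling _use_sdpa. The helper below is a hedged sketch of how such a forward rebinding can work; it is not the repository's convert_functions implementation, and the name rebind_forward is invented here.

import types
import torch.nn as nn

def rebind_forward(module: nn.Module, target_cls: type, new_fn) -> None:
    # Recursively walk the module tree and attach new_fn as a bound "forward"
    # on every submodule that is an instance of target_cls.
    for child in module.children():
        if isinstance(child, target_cls):
            child.forward = types.MethodType(new_fn, child)
        rebind_forward(child, target_cls, new_fn)

# Usage mirroring the patch (FalconModel and _falcon_model_forward come from the diff above):
# rebind_forward(model, FalconModel, _falcon_model_forward)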
@@ -102,9 +105,10 @@ def _patch_gpt2_model(model):
     1. Disable SDPA so the attention mask will be compatible to ipex attention.
     2. Use IAKV cache
     """
-    model.transformer._attn_implementation = "eager"
+    num_key_value_heads = model.config.num_attention_heads
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     return model

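
GPT-2 uses standard multi-head attention, so every query head has its own key/value head; recording num_key_value_heads = num_attention_heads on the config lets the shared paged-cache code treat GPT-2 uniformly with grouped-query models, and the patch point moves from GPT2Block.forward up to GPT2Model.forward. A small sketch of the config side of this change, assuming a stock transformers GPT2Config:

from transformers import GPT2Config

# Stock GPT-2 small config: 12 attention heads (n_head), standard multi-head attention.
config = GPT2Config()
setattr(config, "num_key_value_heads", config.num_attention_heads)

# With MHA there is no KV grouping, so KV heads equal attention heads.
assert config.num_key_value_heads == config.num_attention_heads == 12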
