 # limitations under the License.

 from transformers.models.bert.modeling_bert import BertIntermediate
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
+from transformers.models.falcon.modeling_falcon import FalconModel, FalconDecoderLayer
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ ... @@
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _gpt2_block_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
     _IPEXGPT2Attention,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _llama_model_forward,
+    _falcon_model_forward,
+    _gpt2_model_forward,
 )
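For reference, the two helpers the hunks below call, `convert_functions` and `convert_class`, are defined elsewhere in this file and do not appear in the diff. A minimal sketch of their assumed behavior, reconstructed for illustration only (not the library's actual implementation):

```python
import types

from torch import nn


def convert_functions(m: nn.Module, target_m, new_function_name: str, new_function):
    # Walk the module tree and rebind `new_function` (e.g. _falcon_model_forward)
    # as a bound method on every submodule of the target class.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_function_name, types.MethodType(new_function, sub_m))
        convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m: nn.Module, target_m, new_class, config=None):
    # Walk the module tree and wrap every submodule of the target class
    # (e.g. FalconDecoderLayer) in its IPEX-optimized replacement,
    # reusing the original submodule's weights.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, name, new_class(sub_m, config))
        else:
            convert_class(sub_m, target_m, new_class, config)
```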
@@ -90,7 +91,9 @@ def _patch_falcon_model(model):
     2. Use IPEX Rope and paged cache
     3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
     """
-    model.transformer._use_sdpa = False
+    num_key_value_heads = model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
     convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
     return model
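The new `num_key_value_heads` line folds Falcon's attention layouts into one expression: grouped-query checkpoints (new decoder architecture) and classic multi-head checkpoints keep their configured `num_kv_heads`, while multi-query checkpoints share a single KV head across all query heads. A quick sanity sketch of that rule, using plain namespaces as stand-ins for real `FalconConfig` objects:

```python
from types import SimpleNamespace


def falcon_kv_heads(config):
    # Same expression as the patch above.
    return config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1


# Falcon-40B style checkpoint: new decoder architecture with grouped KV heads.
gqa = SimpleNamespace(new_decoder_architecture=True, multi_query=True, num_kv_heads=8)
# Falcon-7B style checkpoint: classic multi-query attention.
mqa = SimpleNamespace(new_decoder_architecture=False, multi_query=True, num_kv_heads=8)

print(falcon_kv_heads(gqa))  # 8: each KV group is kept
print(falcon_kv_heads(mqa))  # 1: all query heads share one KV head
```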
@@ -102,9 +105,10 @@ def _patch_gpt2_model(model):
     1. Disable SDPA so the attention mask will be compatible to ipex attention.
     2. Use IAKV cache
     """
-    model.transformer._attn_implementation = "eager"
+    num_key_value_heads = model.config.num_attention_heads
+    setattr(model.config, "num_key_value_heads", num_key_value_heads)
+    convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
     convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
     return model
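For GPT-2 the same field is simply the number of attention heads: standard multi-head attention gives every query head its own KV head, and `GPT2Config` does not define `num_key_value_heads` itself, so the patch records it explicitly for the shared attention path. A small illustration, assuming only that `transformers` is installed:

```python
from transformers import GPT2Config

config = GPT2Config()  # gpt2-small defaults: 12 attention heads
# Mirror the patch: add a num_key_value_heads field so the shared IPEX
# attention path can treat GPT-2 like the other patched decoders.
setattr(config, "num_key_value_heads", config.num_attention_heads)
print(config.num_key_value_heads)  # 12
```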