@@ -15,6 +15,6 @@
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -33,12 +33,10 @@
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _IPEXQwen2DecoderLayer,
@@ -66,12 +64,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)


-def convert_class(m, target_m, new_class, config=None):
+def convert_class(m, target_m, new_class, device, config):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config)
+            new_m = new_class(sub_m, device, config)
             setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config)
+        convert_class(sub_m, target_m, new_class, device, config)


 def patch_op(m, target_m, new_op_name, new_op):
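Note on the signature change above: `convert_class` now threads a `device` argument through to each wrapper's constructor, so every `_IPEX*` class passed to it is expected to accept `(module, device, config)`. A minimal, self-contained sketch of how this recursive replacement behaves; `ToyBlock` and `_ToyIPEXBlock` are hypothetical stand-ins, not part of this patch:

import torch
from torch import nn


class ToyBlock(nn.Module):
    # Hypothetical stand-in for a transformers block such as GPT2Block.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class _ToyIPEXBlock(nn.Module):
    # Hypothetical stand-in for an _IPEX* wrapper; note the (module, device, config) signature.
    def __init__(self, module, device, config):
        super().__init__()
        self.module = module.to(device)
        self.config = config

    def forward(self, x):
        return self.module(x)


def convert_class(m, target_m, new_class, device, config):
    # Mirrors the patched helper: swap every matching child for the wrapper, recursing through submodules.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, name, new_class(sub_m, device, config))
        convert_class(sub_m, target_m, new_class, device, config)


model = nn.Sequential(ToyBlock(), ToyBlock())
convert_class(model, ToyBlock, _ToyIPEXBlock, torch.device("cpu"), config=None)
assert all(isinstance(block, _ToyIPEXBlock) for block in model)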
@@ -89,7 +87,7 @@ def _patch_llama_model(model):
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
     return model


@@ -105,21 +103,20 @@ def _patch_falcon_model(model):
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
-    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
     return model


 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
         1. Use IPEX paged attention
+        2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
     return model


@@ -131,7 +128,7 @@ def _patch_qwen2_model(model):
     """
     convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
     convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
     return model


@@ -140,7 +137,7 @@ def _patch_bert_model(model):
     Patch bert model:
         1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
     return model


@@ -149,7 +146,7 @@ def _patch_vit_model(model):
     Patch vit model:
         1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
     return model


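End to end, these helpers patch an already loaded transformers model in place. A rough usage sketch; the import path `optimum.exporters.ipex.model_patcher` and the `"gpt2"` checkpoint are assumptions for illustration, and `_patch_gpt2_model` is normally called internally rather than directly:

# Rough sketch: apply the GPT-2 patch to a stock transformers checkpoint.
from transformers import AutoModelForCausalLM

from optimum.exporters.ipex.model_patcher import _patch_gpt2_model  # assumed module path

model = AutoModelForCausalLM.from_pretrained("gpt2")
model = _patch_gpt2_model(model)  # GPT2Block children are now _IPEXGPT2Block wrappers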