@@ -32,12 +32,6 @@ def matmul_add_add(attn_output, weight, bias=None, residual=None):
32
32
return attn_output
33
33
34
34
35
- def reference_elimination (c , b ):
36
- for item in gc .get_objects ():
37
- if isinstance (item , torch .Tensor ) and item .data_ptr () == c .data_ptr () and item is not c :
38
- item .data = b
39
-
40
-
41
35
class _IPEXLlamaAttentionXPU (_IPEXLlamaAttention ):
42
36
def __init__ (self , module , config , distributed = False , optimized_module = None ) -> None :
43
37
super ().__init__ (module , config , distributed )
@@ -166,23 +160,17 @@ def port_parameters(self, module):
166
160
k_proj = module .k_proj .weight .transpose (0 , 1 )
167
161
v_proj = module .v_proj .weight .transpose (0 , 1 )
168
162
self .qkv_proj_weight = torch .stack ([q_proj , k_proj , v_proj ]).contiguous ().view ([3 , - 1 , q_proj .shape [- 1 ]])
169
- reference_elimination (module .q_proj .weight .data , self .qkv_proj_weight [0 , :, :].transpose (0 , 1 ))
170
163
module .q_proj .weight .data = self .qkv_proj_weight [0 , :, :].transpose (0 , 1 )
171
- reference_elimination (module .k_proj .weight .data , self .qkv_proj_weight [1 , :, :].transpose (0 , 1 ))
172
164
module .k_proj .weight .data = self .qkv_proj_weight [1 , :, :].transpose (0 , 1 )
173
- reference_elimination (module .v_proj .weight .data , self .qkv_proj_weight [2 , :, :].transpose (0 , 1 ))
174
165
module .v_proj .weight .data = self .qkv_proj_weight [2 , :, :].transpose (0 , 1 )
175
166
if module .q_proj .bias is not None :
176
167
self .qkv_proj_bias = (
177
168
torch .stack ([module .q_proj .bias , module .k_proj .bias , module .v_proj .bias ])
178
169
.contiguous ()
179
170
.view ([3 , - 1 ])
180
171
)
181
- reference_elimination (module .q_proj .bias .data , self .qkv_proj_bias [0 ])
182
172
module .q_proj .bias .data = self .qkv_proj_bias [0 ]
183
- reference_elimination (module .k_proj .bias .data , self .qkv_proj_bias [1 ])
184
173
module .k_proj .bias .data = self .qkv_proj_bias [1 ]
185
- reference_elimination (module .v_proj .bias .data , self .qkv_proj_bias [2 ])
186
174
module .v_proj .bias .data = self .qkv_proj_bias [2 ]
187
175
else :
188
176
group = self .num_heads // self .num_kv_heads
@@ -192,26 +180,12 @@ def port_parameters(self, module):
192
180
self .qkv_proj_weight = torch .cat ([q_proj , k_proj , v_proj ], dim = 1 ).view (
193
181
[self .num_kv_heads , group + 2 , self .head_dim , self .embed_dim ]
194
182
)
195
- reference_elimination (
196
- module .q_proj .data ,
197
- self .qkv_proj_weight [:, :group , :, :].view (
198
- [self .num_kv_heads * group * self .head_dim , self .embed_dim ]
199
- ),
200
- )
201
183
module .q_proj .data = self .qkv_proj_weight [:, :group , :, :].view (
202
184
[self .num_kv_heads * group * self .head_dim , self .embed_dim ]
203
185
)
204
- reference_elimination (
205
- module .k_proj .data ,
206
- self .qkv_proj_weight [:, group , :, :].view ([self .num_kv_heads * self .head_dim , self .embed_dim ]),
207
- )
208
186
module .k_proj .data = self .qkv_proj_weight [:, group , :, :].view (
209
187
[self .num_kv_heads * self .head_dim , self .embed_dim ]
210
188
)
211
- reference_elimination (
212
- module .v_proj .data ,
213
- self .qkv_proj_weight [:, group + 1 , :, :].view ([self .num_kv_heads * self .head_dim , self .embed_dim ]),
214
- )
215
189
module .v_proj .data = self .qkv_proj_weight [:, group + 1 , :, :].view (
216
190
[self .num_kv_heads * self .head_dim , self .embed_dim ]
217
191
)
@@ -222,16 +196,10 @@ def port_parameters(self, module):
222
196
self .qkv_proj_bias = torch .cat ([q_bias , k_bias , v_bias ], dim = 1 ).view (
223
197
[self .num_kv_heads , group + 2 , self .head_dim ]
224
198
)
225
- reference_elimination (module .q_proj .bias .data , self .qkv_proj_bias [:, :group , self .head_dim ].view (- 1 ))
226
199
module .q_proj .bias .data = self .qkv_proj_bias [:, :group , self .head_dim ].view (- 1 )
227
- reference_elimination (module .k_proj .bias .data , self .qkv_proj_bias [:, group , self .head_dim ].view (- 1 ))
228
200
module .k_proj .bias .data = self .qkv_proj_bias [:, group , self .head_dim ].view (- 1 )
229
- reference_elimination (
230
- module .v_proj .bias .data , self .qkv_proj_bias [:, group + 1 , self .head_dim ].view (- 1 )
231
- )
232
201
module .v_proj .bias .data = self .qkv_proj_bias [:, group + 1 , self .head_dim ].view (- 1 )
233
202
self .o_proj_weight = module .o_proj .weight .transpose (0 , 1 ).contiguous ()
234
- reference_elimination (module .o_proj .weight .data , self .o_proj_weight .transpose (0 , 1 ))
235
203
module .o_proj .weight .data = self .o_proj_weight .transpose (0 , 1 )
236
204
self .o_proj_bias = module .o_proj .bias
237
205
@@ -243,7 +211,8 @@ def __init__(self, module, config, distributed=False, optimized_module=None) ->
243
211
if optimized_module is not None :
244
212
self .mlp_impl = optimized_module
245
213
self .port_parameter (module )
246
-
214
+ torch .xpu .empty_cache ()
215
+
247
216
def forward (self , hidden_states : torch .Tensor , residual : torch .Tensor = None , ** kwargs ):
248
217
"""
249
218
Args:
@@ -256,13 +225,10 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **
256
225
257
226
def port_parameter (self , module ):
258
227
self .up_proj_weight = module .up_proj .weight .transpose (0 , 1 ).contiguous ()
259
- reference_elimination (module .up_proj .weight .data , self .up_proj_weight .transpose (0 , 1 ))
260
228
module .up_proj .weight .data = self .up_proj_weight .transpose (0 , 1 )
261
229
self .gate_proj_weight = module .gate_proj .weight .transpose (0 , 1 ).contiguous ()
262
- reference_elimination (module .gate_proj .weight .data , self .gate_proj_weight .transpose (0 , 1 ))
263
230
module .gate_proj .weight .data = self .gate_proj_weight .transpose (0 , 1 )
264
231
self .down_proj_weight = module .down_proj .weight .transpose (0 , 1 ).contiguous ()
265
- reference_elimination (module .down_proj .weight .data , self .down_proj_weight .transpose (0 , 1 ))
266
232
module .down_proj .weight .data = self .down_proj_weight .transpose (0 , 1 )
267
233
self .up_proj_bias = module .up_proj .bias
268
234
self .gate_proj_bias = module .gate_proj .bias
0 commit comments