Commit 872a3eb

remove reference elimination
1 parent 3824300 commit 872a3eb

1 file changed (+2, -36 lines)

optimum/exporters/ipex/modeling/xpu/xpu_modeling_llama.py (+2, -36)
@@ -32,12 +32,6 @@ def matmul_add_add(attn_output, weight, bias=None, residual=None):
     return attn_output
 
 
-def reference_elimination(c, b):
-    for item in gc.get_objects():
-        if isinstance(item, torch.Tensor) and item.data_ptr() == c.data_ptr() and item is not c:
-            item.data = b
-
-
 class _IPEXLlamaAttentionXPU(_IPEXLlamaAttention):
     def __init__(self, module, config, distributed=False, optimized_module=None) -> None:
         super().__init__(module, config, distributed)
@@ -166,23 +160,17 @@ def port_parameters(self, module):
             k_proj = module.k_proj.weight.transpose(0, 1)
             v_proj = module.v_proj.weight.transpose(0, 1)
             self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]])
-            reference_elimination(module.q_proj.weight.data, self.qkv_proj_weight[0, :, :].transpose(0, 1))
             module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1)
-            reference_elimination(module.k_proj.weight.data, self.qkv_proj_weight[1, :, :].transpose(0, 1))
             module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1)
-            reference_elimination(module.v_proj.weight.data, self.qkv_proj_weight[2, :, :].transpose(0, 1))
             module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1)
             if module.q_proj.bias is not None:
                 self.qkv_proj_bias = (
                     torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias])
                     .contiguous()
                     .view([3, -1])
                 )
-                reference_elimination(module.q_proj.bias.data, self.qkv_proj_bias[0])
                 module.q_proj.bias.data = self.qkv_proj_bias[0]
-                reference_elimination(module.k_proj.bias.data, self.qkv_proj_bias[1])
                 module.k_proj.bias.data = self.qkv_proj_bias[1]
-                reference_elimination(module.v_proj.bias.data, self.qkv_proj_bias[2])
                 module.v_proj.bias.data = self.qkv_proj_bias[2]
         else:
             group = self.num_heads // self.num_kv_heads
@@ -192,26 +180,12 @@ def port_parameters(self, module):
             self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view(
                 [self.num_kv_heads, group + 2, self.head_dim, self.embed_dim]
             )
-            reference_elimination(
-                module.q_proj.data,
-                self.qkv_proj_weight[:, :group, :, :].view(
-                    [self.num_kv_heads * group * self.head_dim, self.embed_dim]
-                ),
-            )
             module.q_proj.data = self.qkv_proj_weight[:, :group, :, :].view(
                 [self.num_kv_heads * group * self.head_dim, self.embed_dim]
             )
-            reference_elimination(
-                module.k_proj.data,
-                self.qkv_proj_weight[:, group, :, :].view([self.num_kv_heads * self.head_dim, self.embed_dim]),
-            )
             module.k_proj.data = self.qkv_proj_weight[:, group, :, :].view(
                 [self.num_kv_heads * self.head_dim, self.embed_dim]
             )
-            reference_elimination(
-                module.v_proj.data,
-                self.qkv_proj_weight[:, group + 1, :, :].view([self.num_kv_heads * self.head_dim, self.embed_dim]),
-            )
             module.v_proj.data = self.qkv_proj_weight[:, group + 1, :, :].view(
                 [self.num_kv_heads * self.head_dim, self.embed_dim]
             )
@@ -222,16 +196,10 @@ def port_parameters(self, module):
                 self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(
                     [self.num_kv_heads, group + 2, self.head_dim]
                 )
-                reference_elimination(module.q_proj.bias.data, self.qkv_proj_bias[:, :group, self.head_dim].view(-1))
                 module.q_proj.bias.data = self.qkv_proj_bias[:, :group, self.head_dim].view(-1)
-                reference_elimination(module.k_proj.bias.data, self.qkv_proj_bias[:, group, self.head_dim].view(-1))
                 module.k_proj.bias.data = self.qkv_proj_bias[:, group, self.head_dim].view(-1)
-                reference_elimination(
-                    module.v_proj.bias.data, self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1)
-                )
                 module.v_proj.bias.data = self.qkv_proj_bias[:, group + 1, self.head_dim].view(-1)
         self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous()
-        reference_elimination(module.o_proj.weight.data, self.o_proj_weight.transpose(0, 1))
         module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1)
         self.o_proj_bias = module.o_proj.bias
 
@@ -243,7 +211,8 @@ def __init__(self, module, config, distributed=False, optimized_module=None) ->
         if optimized_module is not None:
             self.mlp_impl = optimized_module
         self.port_parameter(module)
-
+        torch.xpu.empty_cache()
+
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         """
         Args:
@@ -256,13 +225,10 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **
 
     def port_parameter(self, module):
         self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous()
-        reference_elimination(module.up_proj.weight.data, self.up_proj_weight.transpose(0, 1))
         module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1)
         self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous()
-        reference_elimination(module.gate_proj.weight.data, self.gate_proj_weight.transpose(0, 1))
         module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1)
         self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous()
-        reference_elimination(module.down_proj.weight.data, self.down_proj_weight.transpose(0, 1))
         module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1)
         self.up_proj_bias = module.up_proj.bias
         self.gate_proj_bias = module.gate_proj.bias
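Note on the removed helper: reference_elimination walked gc.get_objects() looking for any other torch.Tensor that aliased a given parameter's storage (same data_ptr) and repointed its .data at the repacked buffer. After this commit only the module's own q/k/v/o and MLP parameters are repointed through direct .data assignment, and a torch.xpu.empty_cache() call is added after the MLP's port_parameter, presumably to release the now-unreferenced original weight buffers. Below is a minimal standalone sketch of the deleted helper, reconstructed from the removed lines above; the CPU tensors and names in the usage demo (old, alias, new) are illustrative only, since the real code operates on XPU weights.

import gc

import torch


def reference_elimination(c, b):
    # Repoint every live tensor that shares c's storage (other than c itself) at b,
    # so no stale alias keeps the old buffer alive after the weights are repacked.
    for item in gc.get_objects():
        if isinstance(item, torch.Tensor) and item.data_ptr() == c.data_ptr() and item is not c:
            item.data = b


# Illustrative usage (hypothetical names): `alias` shares storage with `old`,
# so after the call it is backed by `new` instead.
old = torch.zeros(4)
alias = old.view(2, 2)
new = torch.ones(4)
reference_elimination(old, new)
print(alias)  # prints ones, since alias.data was rebound to new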
