@@ -178,12 +178,6 @@ def _llama_model_forward(
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
 class _IPEXLlamaAttentionRef(nn.Module):
     def __init__(self, module, config, distributed=False) -> None:
-        if is_ipex_version("<", "2.3.0"):
-            raise ImportError(
-                "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding"
-            )
-        from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
-
         super().__init__()
         for k, v in module.__dict__.items():
             setattr(self, k, v)
@@ -193,19 +187,33 @@ def __init__(self, module, config, distributed=False) -> None:
             setattr(self.__class__, k, getattr(module.__class__, k))
         self.config = config
         self.distributed = distributed
-        if not self.distributed:
-            self.mha_linear_add = LinearAdd(self.o_proj)
-            del self.__dict__["_modules"]["o_proj"]
-        self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
-            text_max_length=module.config.max_position_embeddings
-        )
-        self.ipex_rope = RotaryEmbedding(
-            module.config.max_position_embeddings,
-            module.config.hidden_size // module.config.num_attention_heads,
-            module.config.rope_theta,
-            module.config.architectures[0],
-        )
-
+        self.module_device = module.q_proj.weight.device.type
+        if self.module_device == "xpu":
+            from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU
+            self.ipex_rope = _IPEXRopeXPU(
+                module.config.max_position_embeddings,
+                module.config.hidden_size // module.config.num_attention_heads,
+                module.config.rope_theta,
+                module.config.architectures[0],
+            )
+            self.port_parameters(module)
+            torch.xpu.empty_cache()
+        else:
+            from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
+            if not self.distributed:
+                self.mha_linear_add = LinearAdd(self.o_proj)
+                del self.__dict__["_modules"]["o_proj"]
+            self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
+                text_max_length=module.config.max_position_embeddings
+            )
+            self.ipex_rope = RotaryEmbedding(
+                module.config.max_position_embeddings,
+                module.config.hidden_size // module.config.num_attention_heads,
+                module.config.rope_theta,
+                module.config.architectures[0],
+            )
+
+
     def forward(
         self,
         hidden_states: torch.Tensor,
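The rewritten `__init__` now branches on the device that hosts the projection weights: on `"xpu"` it builds the `_IPEXRopeXPU` fusion and repacks the weights via `port_parameters`, otherwise it keeps the previous CPU path with `LinearAdd`, `IndirectAccessKVCacheAttention` and `RotaryEmbedding`. For reference, the four constructor arguments passed to either rotary-embedding fusion all come straight from the HF model config; a minimal sketch (default `LlamaConfig` used purely for illustration, not part of the patch):

```python
from transformers import LlamaConfig

config = LlamaConfig()  # default LLaMA-7B-like hyperparameters, illustration only
max_pos = config.max_position_embeddings                     # rotary table length (2048 by default)
head_dim = config.hidden_size // config.num_attention_heads  # per-head dimension, 4096 // 32 = 128
theta = config.rope_theta                                    # RoPE base frequency (10000.0 by default)
arch = config.architectures[0] if config.architectures else "LlamaForCausalLM"
```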
@@ -310,9 +318,60 @@ def forward(
         attn_weights = None
 
         return attn_output, past_key_value, attn_weights
-
-
-class _IPEXLlamaMLP(nn.Module):
+
+    def port_parameters(self, module):
+        self.qkv_proj_bias = None
+        self.qkv_proj_weight = None
+        if self.num_heads == self.num_key_value_heads:
+            q_proj = module.q_proj.weight.transpose(0, 1)
+            k_proj = module.k_proj.weight.transpose(0, 1)
+            v_proj = module.v_proj.weight.transpose(0, 1)
+            self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]])
+            module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1)
+            module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1)
+            module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1)
+            if module.q_proj.bias is not None:
+                self.qkv_proj_bias = (
+                    torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias])
+                    .contiguous()
+                    .view([3, -1])
+                )
+                module.q_proj.bias.data = self.qkv_proj_bias[0]
+                module.k_proj.bias.data = self.qkv_proj_bias[1]
+                module.v_proj.bias.data = self.qkv_proj_bias[2]
+        else:
+            q_proj = module.q_proj.weight.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.embed_dim)
+            k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim)
+            v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim)
+            self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view(
+                [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.embed_dim]
+            )
+            module.q_proj.data = self.qkv_proj_weight[:, :self.num_key_value_groups, :, :].view(
+                [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.embed_dim]
+            )
+            module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].view(
+                [self.num_kv_heads * self.head_dim, self.embed_dim]
+            )
+            module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].view(
+                [self.num_kv_heads * self.head_dim, self.embed_dim]
+            )
+            if module.q_proj.bias is not None:
+                q_bias = module.q_proj.bias.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim)
+                k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim)
+                v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim)
+                self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(
+                    [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim]
+                )
+                module.q_proj.bias.data = self.qkv_proj_bias[:, :self.num_key_value_groups, self.head_dim].view(-1)
+                module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1)
+                module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1)
+
+        self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous()
+        module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1)
+        self.o_proj_bias = module.o_proj.bias
+
+
+class _IPEXLlamaMLPRef(nn.Module):
     def __init__(self, module, config, distributed=False) -> None:
         if is_ipex_version("<", "2.3.0"):
             raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
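For reference, the MHA branch of the new `port_parameters` stacks the transposed Q/K/V weights into a single contiguous `[3, in_features, out_features]` buffer and then points the original `nn.Linear` weights back at transposed slices of it, so the module keeps producing the same projections. A minimal plain-PyTorch sketch of that round trip (toy sizes, my own check, not taken from the patch):

```python
import torch

hidden = 64  # toy size, illustration only
q_w = torch.randn(hidden, hidden)  # q_proj/k_proj/v_proj weights, shape [out_features, in_features]
k_w = torch.randn(hidden, hidden)
v_w = torch.randn(hidden, hidden)

# Same packing as port_parameters when num_heads == num_key_value_heads:
# stack the transposed weights, then view as [3, in_features, out_features].
packed = torch.stack([q_w.t(), k_w.t(), v_w.t()]).contiguous().view([3, -1, q_w.t().shape[-1]])

# Each original projection is recoverable as a transposed slice of the packed buffer.
assert torch.equal(packed[0].t(), q_w)
assert torch.equal(packed[1].t(), k_w)
assert torch.equal(packed[2].t(), v_w)
```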
@@ -327,31 +386,50 @@ def __init__(self, module, config, distributed=False) -> None:
327
386
setattr (self .__class__ , k , getattr (module .__class__ , k ))
328
387
self .config = config
329
388
self .distributed = distributed
330
- if not self .distributed :
331
- self .mlp_linear_add = LinearAdd (module .down_proj )
332
- del self .__dict__ ["_modules" ]["down_proj" ]
333
- self .linear_silu_mul = Linear2SiluMul (module .gate_proj , module .up_proj )
334
- del self .__dict__ ["_modules" ]["gate_proj" ]
335
- del self .__dict__ ["_modules" ]["up_proj" ]
389
+ self .module_device = module .gate_proj .weight .device .type
390
+ if self .module_device == "xpu" :
391
+ self .port_parameter (module )
392
+ torch .xpu .empty_cache ()
393
+ else :
394
+ if not self .distributed :
395
+ self .mlp_linear_add = LinearAdd (module .down_proj )
396
+ del self .__dict__ ["_modules" ]["down_proj" ]
397
+ self .linear_silu_mul = Linear2SiluMul (module .gate_proj , module .up_proj )
398
+ del self .__dict__ ["_modules" ]["gate_proj" ]
399
+ del self .__dict__ ["_modules" ]["up_proj" ]
336
400
337
401
def forward (self , hidden_states : torch .Tensor , residual : torch .Tensor = None , ** kwargs ):
338
402
"""
339
403
Args:
340
404
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
341
405
"""
342
- if hasattr (self , "linear_silu_mul" ):
343
- mlp_gate = self .linear_silu_mul (hidden_states )
344
- if hasattr (self , "mlp_linear_add" ):
345
- hidden_states = self .mlp_linear_add (mlp_gate , residual )
346
- else :
347
- hidden_states = self .down_proj (mlp_gate )
348
- hidden_states = residual + hidden_states
406
+ if self .module_device == "xpu" :
407
+ up = torch .ops .torch_ipex .mm_silu (hidden_states , self .gate_proj_weight )
408
+ hidden_states = torch .ops .torch_ipex .mm_resmul (hidden_states , self .up_proj_weight , up )
409
+ hidden_states = matmul_add_add (hidden_states , self .down_proj_weight , self .down_proj_bias , residual )
349
410
else :
350
- hidden_states = self .down_proj (self .act_fn (self .gate_proj (hidden_states )) * self .up_proj (hidden_states ))
351
- hidden_states = residual + hidden_states
352
-
411
+ if hasattr (self , "linear_silu_mul" ):
412
+ mlp_gate = self .linear_silu_mul (hidden_states )
413
+ if hasattr (self , "mlp_linear_add" ):
414
+ hidden_states = self .mlp_linear_add (mlp_gate , residual )
415
+ else :
416
+ hidden_states = self .down_proj (mlp_gate )
417
+ hidden_states = residual + hidden_states
418
+ else :
419
+ hidden_states = self .down_proj (self .act_fn (self .gate_proj (hidden_states )) * self .up_proj (hidden_states ))
420
+ hidden_states = residual + hidden_states
353
421
return hidden_states
354
422
423
+ def port_parameter (self , module ):
424
+ self .up_proj_weight = module .up_proj .weight .transpose (0 , 1 ).contiguous ()
425
+ module .up_proj .weight .data = self .up_proj_weight .transpose (0 , 1 )
426
+ self .gate_proj_weight = module .gate_proj .weight .transpose (0 , 1 ).contiguous ()
427
+ module .gate_proj .weight .data = self .gate_proj_weight .transpose (0 , 1 )
428
+ self .down_proj_weight = module .down_proj .weight .transpose (0 , 1 ).contiguous ()
429
+ module .down_proj .weight .data = self .down_proj_weight .transpose (0 , 1 )
430
+ self .up_proj_bias = module .up_proj .bias
431
+ self .gate_proj_bias = module .gate_proj .bias
432
+ self .down_proj_bias = module .down_proj .bias
355
433
356
434
# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
357
435
class _IPEXLlamaDecoderLayerRef (nn .Module ):
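The XPU branch of `_IPEXLlamaMLPRef.forward` expresses the LLaMA SwiGLU MLP through fused `torch_ipex` ops acting on the pre-transposed weights stored by `port_parameter`. Assuming `mm_silu` fuses matmul+SiLU, `mm_resmul` fuses matmul+elementwise multiply, and `matmul_add_add` fuses matmul+bias+residual (my reading of these ops, not stated in the patch), the result matches the eager fallback in the `else` branch, i.e. `down_proj(silu(gate_proj(x)) * up_proj(x)) + residual`. A plain-PyTorch sketch of the equivalent math:

```python
import torch
import torch.nn.functional as F

def swiglu_mlp_reference(hidden_states, gate_w, up_w, down_w, down_b, residual):
    # gate_w / up_w / down_w are the [in_features, out_features] (transposed) weights
    # that port_parameter stores for the XPU path.
    up = F.silu(hidden_states @ gate_w)   # ~ torch.ops.torch_ipex.mm_silu (assumed semantics)
    gated = (hidden_states @ up_w) * up   # ~ torch.ops.torch_ipex.mm_resmul (assumed semantics)
    out = gated @ down_w                  # matmul part of matmul_add_add
    if down_b is not None:
        out = out + down_b                # bias add of matmul_add_add
    return out + residual                 # residual add of matmul_add_add
```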
@@ -365,7 +443,7 @@ def __init__(self, module, config, distributed=False):
             setattr(self.__class__, k, getattr(module.__class__, k))
         self.distributed = distributed
         self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed)
-        self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
+        self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed)
     def forward(
         self,
449
self ,