
Commit 89e10d6

add xpu port
1 parent: 3b8900d

File tree: 2 files changed, +132 −60 lines changed


optimum/exporters/ipex/modeling_utils.py (+117 −39)
@@ -178,12 +178,6 @@ def _llama_model_forward(
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
 class _IPEXLlamaAttentionRef(nn.Module):
     def __init__(self, module, config, distributed=False) -> None:
-        if is_ipex_version("<", "2.3.0"):
-            raise ImportError(
-                "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding"
-            )
-        from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
-
         super().__init__()
         for k, v in module.__dict__.items():
             setattr(self, k, v)
@@ -193,19 +187,33 @@ def __init__(self, module, config, distributed=False) -> None:
             setattr(self.__class__, k, getattr(module.__class__, k))
         self.config = config
         self.distributed = distributed
-        if not self.distributed:
-            self.mha_linear_add = LinearAdd(self.o_proj)
-            del self.__dict__["_modules"]["o_proj"]
-        self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
-            text_max_length=module.config.max_position_embeddings
-        )
-        self.ipex_rope = RotaryEmbedding(
-            module.config.max_position_embeddings,
-            module.config.hidden_size // module.config.num_attention_heads,
-            module.config.rope_theta,
-            module.config.architectures[0],
-        )
-
+        self.module_device = module.q_proj.weight.device.type
+        if self.module_device == "xpu":
+            from intel_extension_for_pytorch.transformers.models.xpu.fusions.mha_fusion import _IPEXRopeXPU
+            self.ipex_rope = _IPEXRopeXPU(
+                module.config.max_position_embeddings,
+                module.config.hidden_size // module.config.num_attention_heads,
+                module.config.rope_theta,
+                module.config.architectures[0],
+            )
+            self.port_parameters(module)
+            torch.xpu.empty_cache()
+        else:
+            from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
+            if not self.distributed:
+                self.mha_linear_add = LinearAdd(self.o_proj)
+                del self.__dict__["_modules"]["o_proj"]
+            self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
+                text_max_length=module.config.max_position_embeddings
+            )
+            self.ipex_rope = RotaryEmbedding(
+                module.config.max_position_embeddings,
+                module.config.hidden_size // module.config.num_attention_heads,
+                module.config.rope_theta,
+                module.config.architectures[0],
+            )
+
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -310,9 +318,60 @@ def forward(
         attn_weights = None

         return attn_output, past_key_value, attn_weights
-
-
-class _IPEXLlamaMLP(nn.Module):
+
+    def port_parameters(self, module):
+        self.qkv_proj_bias = None
+        self.qkv_proj_weight = None
+        if self.num_heads == self.num_key_value_heads:
+            q_proj = module.q_proj.weight.transpose(0, 1)
+            k_proj = module.k_proj.weight.transpose(0, 1)
+            v_proj = module.v_proj.weight.transpose(0, 1)
+            self.qkv_proj_weight = torch.stack([q_proj, k_proj, v_proj]).contiguous().view([3, -1, q_proj.shape[-1]])
+            module.q_proj.weight.data = self.qkv_proj_weight[0, :, :].transpose(0, 1)
+            module.k_proj.weight.data = self.qkv_proj_weight[1, :, :].transpose(0, 1)
+            module.v_proj.weight.data = self.qkv_proj_weight[2, :, :].transpose(0, 1)
+            if module.q_proj.bias is not None:
+                self.qkv_proj_bias = (
+                    torch.stack([module.q_proj.bias, module.k_proj.bias, module.v_proj.bias])
+                    .contiguous()
+                    .view([3, -1])
+                )
+                module.q_proj.bias.data = self.qkv_proj_bias[0]
+                module.k_proj.bias.data = self.qkv_proj_bias[1]
+                module.v_proj.bias.data = self.qkv_proj_bias[2]
+        else:
+            q_proj = module.q_proj.weight.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim, self.embed_dim)
+            k_proj = module.k_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim)
+            v_proj = module.v_proj.weight.view(self.num_kv_heads, 1, self.head_dim, self.embed_dim)
+            self.qkv_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=1).view(
+                [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim, self.embed_dim]
+            )
+            module.q_proj.data = self.qkv_proj_weight[:, :self.num_key_value_groups, :, :].view(
+                [self.num_kv_heads * self.num_key_value_groups * self.head_dim, self.embed_dim]
+            )
+            module.k_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups, :, :].view(
+                [self.num_kv_heads * self.head_dim, self.embed_dim]
+            )
+            module.v_proj.data = self.qkv_proj_weight[:, self.num_key_value_groups + 1, :, :].view(
+                [self.num_kv_heads * self.head_dim, self.embed_dim]
+            )
+            if module.q_proj.bias is not None:
+                q_bias = module.q_proj.bias.view(self.num_kv_heads, self.num_key_value_groups, self.head_dim)
+                k_bias = module.k_proj.bias.view(self.num_kv_heads, 1, self.head_dim)
+                v_bias = module.v_proj.bias.view(self.num_kv_heads, 1, self.head_dim)
+                self.qkv_proj_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(
+                    [self.num_kv_heads, self.num_key_value_groups + 2, self.head_dim]
+                )
+                module.q_proj.bias.data = self.qkv_proj_bias[:, :self.num_key_value_groups, self.head_dim].view(-1)
+                module.k_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups, self.head_dim].view(-1)
+                module.v_proj.bias.data = self.qkv_proj_bias[:, self.num_key_value_groups + 1, self.head_dim].view(-1)
+
+        self.o_proj_weight = module.o_proj.weight.transpose(0, 1).contiguous()
+        module.o_proj.weight.data = self.o_proj_weight.transpose(0, 1)
+        self.o_proj_bias = module.o_proj.bias
+
+
+class _IPEXLlamaMLPRef(nn.Module):
     def __init__(self, module, config, distributed=False) -> None:
         if is_ipex_version("<", "2.3.0"):
             raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
@@ -327,31 +386,50 @@ def __init__(self, module, config, distributed=False) -> None:
             setattr(self.__class__, k, getattr(module.__class__, k))
         self.config = config
         self.distributed = distributed
-        if not self.distributed:
-            self.mlp_linear_add = LinearAdd(module.down_proj)
-            del self.__dict__["_modules"]["down_proj"]
-        self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
-        del self.__dict__["_modules"]["gate_proj"]
-        del self.__dict__["_modules"]["up_proj"]
+        self.module_device = module.gate_proj.weight.device.type
+        if self.module_device == "xpu":
+            self.port_parameter(module)
+            torch.xpu.empty_cache()
+        else:
+            if not self.distributed:
+                self.mlp_linear_add = LinearAdd(module.down_proj)
+                del self.__dict__["_modules"]["down_proj"]
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+            del self.__dict__["_modules"]["gate_proj"]
+            del self.__dict__["_modules"]["up_proj"]

     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
-        if hasattr(self, "linear_silu_mul"):
-            mlp_gate = self.linear_silu_mul(hidden_states)
-            if hasattr(self, "mlp_linear_add"):
-                hidden_states = self.mlp_linear_add(mlp_gate, residual)
-            else:
-                hidden_states = self.down_proj(mlp_gate)
-                hidden_states = residual + hidden_states
+        if self.module_device == "xpu":
+            up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight)
+            hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up)
+            hidden_states = matmul_add_add(hidden_states, self.down_proj_weight, self.down_proj_bias, residual)
         else:
-            hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
-            hidden_states = residual + hidden_states
-
+            if hasattr(self, "linear_silu_mul"):
+                mlp_gate = self.linear_silu_mul(hidden_states)
+                if hasattr(self, "mlp_linear_add"):
+                    hidden_states = self.mlp_linear_add(mlp_gate, residual)
+                else:
+                    hidden_states = self.down_proj(mlp_gate)
+                    hidden_states = residual + hidden_states
+            else:
+                hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+                hidden_states = residual + hidden_states
         return hidden_states

+    def port_parameter(self, module):
+        self.up_proj_weight = module.up_proj.weight.transpose(0, 1).contiguous()
+        module.up_proj.weight.data = self.up_proj_weight.transpose(0, 1)
+        self.gate_proj_weight = module.gate_proj.weight.transpose(0, 1).contiguous()
+        module.gate_proj.weight.data = self.gate_proj_weight.transpose(0, 1)
+        self.down_proj_weight = module.down_proj.weight.transpose(0, 1).contiguous()
+        module.down_proj.weight.data = self.down_proj_weight.transpose(0, 1)
+        self.up_proj_bias = module.up_proj.bias
+        self.gate_proj_bias = module.gate_proj.bias
+        self.down_proj_bias = module.down_proj.bias

 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayerRef(nn.Module):
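Note (not part of the diff): the XPU branch of `forward` above calls the IPEX fused kernels `torch.ops.torch_ipex.mm_silu`, `torch.ops.torch_ipex.mm_resmul`, and `matmul_add_add`. Assuming those kernels fuse the standard Llama MLP (an assumption here, not taken from IPEX documentation), the unfused plain-PyTorch equivalent would be:

import torch
import torch.nn.functional as F

def llama_mlp_unfused(hidden_states, gate_proj_weight, up_proj_weight, down_proj_weight,
                      down_proj_bias=None, residual=None):
    # Weights are assumed to be pre-transposed to [in_features, out_features],
    # i.e. the layout produced by port_parameter above.
    gate = F.silu(hidden_states @ gate_proj_weight)   # assumed equivalent of mm_silu
    inter = (hidden_states @ up_proj_weight) * gate   # assumed equivalent of mm_resmul
    out = inter @ down_proj_weight                    # matmul part of matmul_add_add
    if down_proj_bias is not None:
        out = out + down_proj_bias
    if residual is not None:
        out = out + residual
    return out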
@@ -365,7 +443,7 @@ def __init__(self, module, config, distributed=False):
             setattr(self.__class__, k, getattr(module.__class__, k))
         self.distributed = distributed
         self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed)
-        self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
+        self.mlp = _IPEXLlamaMLPRef(module.mlp, config, distributed)

     def forward(
         self,

optimum/intel/ipex/modeling_base.py (+15 −21)
@@ -140,10 +140,12 @@ def __init__(
         self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model_save_dir = model_save_dir
         self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
-
-        self.input_names = {
-            inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
-        }
+        if self._device.type == "cpu":
+            self.input_names = {
+                inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
+            }
+        else:
+            self.input_names = {"past_key_values": None, "position_ids": None}
         # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
         AutoConfig.register(self.base_model_prefix, AutoConfig)
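Note (not part of the diff): on the CPU path the wrapped model is a TorchScript module, so the accepted forward inputs can be read from the traced graph, while the XPU path keeps an eager transformers module and uses a fixed set instead. A toy sketch of the graph-based extraction (the exact debug names, e.g. a possible `.1` suffix, are an assumption; the filter below splits before comparing against "self", slightly more defensive than the one-liner in the diff):

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        return input_ids + attention_mask

example = (torch.ones(1, 4, dtype=torch.long), torch.ones(1, 4, dtype=torch.long))
traced = torch.jit.trace(Toy(), example)

# Graph inputs carry debug names derived from the forward signature,
# plus a leading "self" value for the module itself.
input_names = {
    i.debugName().split(".")[0]
    for i in traced.graph.inputs()
    if i.debugName().split(".")[0] != "self"
}
print(input_names)  # expected: {'input_ids', 'attention_mask'} (set order may vary)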
@@ -169,7 +171,6 @@ def _from_transformers(
         trust_remote_code: bool = False,
         _commit_hash: str = None,
     ):
-        device_map = kwargs.pop("device_map", None)
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
@@ -196,22 +197,15 @@

         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)

-        if "cpu" in str(model.device):
-            if is_torch_version("<", "2.1.0"):
-                raise ImportError("`torch>=2.1.0` is needed to trace your model")
-            traced_model = ipex_jit_trace(model, task, use_cache)
-            config.torchscript = True
-            config.torch_dtype = torch_dtype
-            return cls(traced_model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False)
-        else:
-            from optimum.exporters.ipex.model_patcher import _patch_model
-
+        if is_torch_xpu_available(check_device=True):
+            model.to("xpu:0")
             if _is_patched_with_ipex(model, task):
                 model = _patch_model(model)
-            else:
-                raise NotImplementedError(f"The given model is not support yet")
-
-            return model
+        else:
+            model = ipex_jit_trace(model, task, use_cache)
+            config.torchscript = True
+            config.torch_dtype = torch_dtype
+        return cls(model, config=config, model_save_dir=model_id, use_cache=use_cache, warmup=False)

     @classmethod
     def _from_pretrained(
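Note (not part of the diff): after this change `_from_transformers` chooses the backend at load time: with `is_torch_xpu_available(check_device=True)` the model is moved to `xpu:0` and patched in place when supported, otherwise it is traced with `ipex_jit_trace`, and in both cases the result is wrapped by `cls(...)`. A usage sketch against the public `IPEXModelForCausalLM` API (checkpoint, dtype, and generation settings are placeholders, not taken from the commit):

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True routes through _from_transformers: XPU -> patch + .to("xpu:0"),
# CPU -> ipex_jit_trace + torchscript config.
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16)

inputs = tokenizer("The XPU port of the IPEX Llama path", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))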
@@ -462,7 +456,7 @@ def __init__(
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

-        if self._is_ipex_exported:
+        if self._is_ipex_exported and self._device.type == "cpu":
             self._reorder_cache = _ipex_reorder_cache
         else:
             # Check if _reorder_cache is a static method
@@ -552,7 +546,7 @@ def forward(
         if "position_ids" in self.input_names or not self.input_names:
            inputs["position_ids"] = position_ids

-        if self.use_cache:
+        if self.use_cache and self._device.type == "cpu":
            if past_key_values is None:
                past_key_values = self._prepare_past_key_values(input_ids)
