Commit 6e8cd3d

Add IPEX model patcher (#567)
* llama model patcher
* fix jit model
* fix jit model
* rm autocast in model
* add llama model patcher
* support assisted decoding and add reorder cache function
* add comment for _prepare_past_key_values
* rebase main
* fix model_dtype
* rm useless comments
* fix llama
* add comments for ipex_rope and ipex_scale_dot_product
* fix comments
* add enable_tpp comments
* fix import
* fix review round 2
* add torch.no_grad to avoid auto_kernel_selection issue
* use torch.no_grad in jit trace
* fix ipex model testing
* add tests for ipex model generation with multiple inputs
* fix code style
* remove __get__(self) as _reorder_cache is a static method for the class
* fix reorder_cache
* use model_type
* check if reorder_cache is a static method
* fix _reorder_cache
* fix raise import error
* test ipex patching
* fix comments
* update API name and testing
* disable until ipex version 2.5.0
* update testing name
* Update optimum/intel/ipex/modeling_base.py
  Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
* Update tests/ipex/test_modeling.py
  Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
* fix tests

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent c356aa3 commit 6e8cd3d

5 files changed, +558 -14 lines changed

optimum/exporters/ipex/__init__.py

Whitespace-only changes.
optimum/exporters/ipex/model_patcher.py

+91
@@ -0,0 +1,91 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
)

from optimum.intel.utils.import_utils import is_ipex_version

from .modeling_utils import (
    _IPEXLlamaDecoderLayerRef,
    _llama_attn_forward,
    _llama_layer_norm_forward,
    _llama_model_forward,
)


_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
_IPEX_EXPORTED_TASK = ("text-generation",)


def convert_func(m, func_name, new_function):
    bound_method = new_function.__get__(m, m.__class__)
    setattr(m, func_name, bound_method)


def convert_functions(m, target_m, new_function_name, new_function):
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            convert_func(sub_m, new_function_name, new_function)
        convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m, target_m, new_class, config, distributed=False):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, config, distributed)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, config, distributed)


def patch_op(m, target_m, new_op_name, new_op):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_op_name, new_op)
        patch_op(sub_m, target_m, new_op_name, new_op)


def _patch_llama_model(model):
    if is_ipex_version("<", "2.5.0"):
        raise ImportError("Only ipex version >= 2.5.0 supports RotaryEmbedding and IndirectAccessKVCache")

    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding

    ipex_rope = RotaryEmbedding(
        model.config.max_position_embeddings,
        model.config.hidden_size // model.config.num_attention_heads,
        model.config.rope_theta,
        model.config.architectures[0],
    )
    ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

    convert_functions(model, LlamaModel, "forward", _llama_model_forward)
    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
    convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)

    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
    return model


def _patch_model(model):
    if isinstance(model, LlamaForCausalLM):
        model = _patch_llama_model(model)
    return model
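The helpers above do all of their rewiring by walking named_children() recursively. Below is a minimal, self-contained sketch (not part of this commit) of that mechanism on a toy module tree: ToyAttention, ToyBlock, and _toy_attn_forward are hypothetical stand-ins for LlamaAttention and its patched forward, and the snippet assumes patch_op and convert_functions from the file above are in scope (for example, by appending the snippet to that module).

import torch
from torch import nn


class ToyAttention(nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class ToyBlock(nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.attn = ToyAttention(dim)

    def forward(self, x):
        return self.attn(x)


def _toy_attn_forward(self, x):
    # Replacement forward; convert_functions binds it onto every ToyAttention instance,
    # so it can reach ops attached by patch_op through `self`.
    return self.extra_op(self.proj(x)) * 2


model = nn.Sequential(ToyBlock(), ToyBlock())

# Mirrors patch_op(model, LlamaAttention, "ipex_rope", ipex_rope): attach an op to every match.
patch_op(model, ToyAttention, "extra_op", nn.Identity())

# Mirrors convert_functions(model, LlamaAttention, "forward", _llama_attn_forward):
# the new function shadows the class forward on each matching instance.
convert_functions(model, ToyAttention, "forward", _toy_attn_forward)

print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 4]), now produced by the patched forward

convert_class works the same way but swaps the whole submodule for a wrapper class, which is how _patch_llama_model replaces every LlamaDecoderLayer with _IPEXLlamaDecoderLayerRef.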
optimum/exporters/ipex/modeling_utils.py

+307
@@ -0,0 +1,307 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv

from optimum.intel.utils.import_utils import is_ipex_version


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
def _llama_layer_norm_forward(self, hidden_states):
    return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
def _llama_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)

    kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len

    query = query.view(bsz, q_len, self.num_heads, self.head_dim)
    key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    # Use the IPEX op to apply the rotary position embedding more efficiently.
    key = self.ipex_rope(
        key,
        position_ids,
        self.num_key_value_heads,
        self.head_dim,
        self.head_dim // 2,
        self.head_dim,
        kv_seq_len,
    )
    query = self.ipex_rope(
        query,
        position_ids,
        self.num_heads,
        self.head_dim,
        self.head_dim // 2,
        self.head_dim,
        kv_seq_len,
    )

    if use_cache:
        # This IPEX op pre-allocates buffers for past_key_values and uses the beam index history
        # to decide which beam should be used, which makes the attention scaled dot-product more efficient.
        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
            query,
            key,
            value,
            math.sqrt(self.head_dim),
            past_key_value,
            None,
            attention_mask,
        )
    else:
        value_states = value.transpose(1, 2)
        query_states = query.transpose(1, 2)
        key_states = key.transpose(1, 2)
        kv_seq_len = key_states.shape[-2]

        past_key_value = None
        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
def _llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if getattr(self.config, "_flash_attn_2_enabled", False):
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayerRef(nn.Module):
    def __init__(self, module, config, distributed=False):
        if is_ipex_version("<", "2.5.0"):
            raise ImportError("Only ipex version >= 2.5.0 supports Linear2SiluMul and LinearAdd")

        from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

        super().__init__()
        for k, v in module.__dict__.items():
            setattr(self, k, v)
        for k, v in module.__class__.__dict__.items():
            if k.startswith("__") or k.startswith("forward"):
                continue
            setattr(self.__class__, k, getattr(module.__class__, k))
        self.distributed = distributed
        if not self.distributed:
            self.mha_linear_add = LinearAdd(module.self_attn.o_proj)
            self.mlp_linear_add = LinearAdd(module.mlp.down_proj)
            del self.__dict__["_modules"]["self_attn"].o_proj
            del self.__dict__["_modules"]["mlp"].down_proj
        self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
        del self.__dict__["_modules"]["mlp"].gate_proj
        del self.__dict__["_modules"]["mlp"].up_proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        if not self.distributed:
            hidden_states = self.mha_linear_add(hidden_states, residual)
        else:
            hidden_states = self.self_attn.o_proj(hidden_states)
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        mlp_gate = self.linear_silu_mul(hidden_states)

        if not self.distributed:
            hidden_states = self.mlp_linear_add(mlp_gate, residual)
        else:
            hidden_states = self.mlp.down_proj(mlp_gate)
            hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
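Taken together, the two files let a stock LlamaForCausalLM be rewritten in place before tracing. Below is a hedged, end-to-end sketch (not part of the diff): it assumes intel-extension-for-pytorch >= 2.5.0 is installed (the version checks above gate the whole path on that), that _patch_model from the patcher file is in scope or importable, and that the tiny checkpoint id is only a placeholder.

import torch
from transformers import AutoModelForCausalLM

# Any Llama checkpoint works; this tiny test model is just a placeholder id.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# _patch_model is defined in the patcher file above; it dispatches to _patch_llama_model
# for LlamaForCausalLM instances.
model = _patch_model(model)

layer = model.model.layers[0]
print(type(layer).__name__)                                # _IPEXLlamaDecoderLayerRef
print(hasattr(layer.self_attn, "ipex_rope"))               # True: RotaryEmbedding op attached by patch_op
print(hasattr(layer.self_attn, "ipex_scale_dot_product"))  # True: IndirectAccessKVCache attached by patch_op
print(hasattr(layer, "mha_linear_add"))                    # True (non-distributed): fused o_proj + residual add

In the actual integration, the patching is driven from optimum/intel/ipex/modeling_base.py and combined with torch.jit.trace under torch.no_grad(), as the commit message notes; the sketch above only inspects the patched module structure.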
