Commit d1d0ca0

refactor IPEXLlamaAttention
1 parent 5351f4a · commit d1d0ca0

File tree

2 files changed: +184 -129 lines changed


optimum/exporters/ipex/model_patcher.py (+1 -17)
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
     LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaModel,
@@ -24,7 +23,6 @@
 
 from .modeling_utils import (
     _IPEXLlamaDecoderLayerRef,
-    _llama_attn_forward,
     _llama_layer_norm_forward,
     _llama_model_forward,
 )
@@ -63,24 +61,10 @@ def patch_op(m, target_m, new_op_name, new_op):
 
 def _patch_llama_model(model):
     if is_ipex_version("<", "2.3.0"):
-        raise ImportError("Only ipex version >= 2.3.0 supports RotaryEmbedding and IndirectAccessKVCacheAttention")
-
-    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding
-
-    ipex_rope = RotaryEmbedding(
-        model.config.max_position_embeddings,
-        model.config.hidden_size // model.config.num_attention_heads,
-        model.config.rope_theta,
-        model.config.architectures[0],
-    )
-    ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings)
-    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
-    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
+        raise ImportError("Only ipex version >= 2.3.0 supports llama model patching")
 
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
-    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)
-
     convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
     return model
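
For orientation (not part of the commit itself): a minimal sketch of how the slimmed-down entry point might be exercised. The checkpoint name and the direct import of the private helper are assumptions for illustration only. The visible effect of the refactor is that _patch_llama_model no longer attaches ipex_rope / ipex_scale_dot_product to LlamaAttention via patch_op; those ops now live inside the attention wrapper defined in modeling_utils.py below.

import torch
from transformers import AutoModelForCausalLM

from optimum.exporters.ipex.model_patcher import _patch_llama_model

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint; any Llama config should do
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Requires intel_extension_for_pytorch >= 2.3.0, otherwise the ImportError above is raised.
model = _patch_llama_model(model)

# After convert_class runs, each decoder layer is expected to be an _IPEXLlamaDecoderLayerRef
# wrapping the weights of the original LlamaDecoderLayer (see modeling_utils.py below).
print(type(model.model.layers[0]).__name__)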

optimum/exporters/ipex/modeling_utils.py (+183 -112)
@@ -29,90 +29,6 @@ def _llama_layer_norm_forward(self, hidden_states):
     return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
-def _llama_attn_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    query = self.q_proj(hidden_states)
-    key = self.k_proj(hidden_states)
-    value = self.v_proj(hidden_states)
-
-    kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len
-
-    query = query.view(bsz, q_len, self.num_heads, self.head_dim)
-    key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    # Use ipex op to rotary position embedding more efficient.
-    key = self.ipex_rope(
-        key,
-        position_ids,
-        self.num_key_value_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
-    query = self.ipex_rope(
-        query,
-        position_ids,
-        self.num_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
-
-    if use_cache:
-        # This ipex op pre-allocates buffers for past_key_values and use beam index history
-        # which to decide which beam should be used to make attention scale dot more efficient.
-        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
-            query,
-            key,
-            value,
-            math.sqrt(self.head_dim),
-            past_key_value,
-            None,
-            attention_mask,
-        )
-    else:
-        value_states = value.transpose(1, 2)
-        query_states = query.transpose(1, 2)
-        key_states = key.transpose(1, 2)
-        kv_seq_len = key_states.shape[-2]
-
-        past_key_value = None
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-    attn_output = attn_output.transpose(1, 2)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
-
-
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
 def _llama_model_forward(
     self,
@@ -216,12 +132,147 @@ def _llama_model_forward(
     )
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
-class _IPEXLlamaDecoderLayerRef(nn.Module):
-    def __init__(self, module, config, distributed=False):
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
+class _IPEXLlamaAttentionRef(nn.Module):
+    def __init__(self, module, config, distributed=False) -> None:
         if is_ipex_version("<", "2.3.0"):
-            raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
+            raise ImportError(
+                "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding"
+            )
+        from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
+
+        super().__init__()
+        for k, v in module.__dict__.items():
+            setattr(self, k, v)
+        for k, v in module.__class__.__dict__.items():
+            if k.startswith("__") or k.startswith("forward"):
+                continue
+            setattr(self.__class__, k, getattr(module.__class__, k))
+        self.config = config
+        self.distributed = distributed
+        if not self.distributed:
+            self.mha_linear_add = LinearAdd(self.o_proj)
+            del self.__dict__["_modules"]["o_proj"]
+        self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
+            text_max_length=module.config.max_position_embeddings
+        )
+        self.ipex_rope = RotaryEmbedding(
+            module.config.max_position_embeddings,
+            module.config.hidden_size // module.config.num_attention_heads,
+            module.config.rope_theta,
+            module.config.architectures[0],
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        residual: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            residual (`torch.Tensor`): residual tensor to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+        bsz, seq_len, _ = hidden_states.size()
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
 
+        kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len
+
+        query = query.view(bsz, seq_len, self.num_heads, self.head_dim)
+        key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
+        value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
+        # Use the ipex rope op to apply rotary position embeddings more efficiently.
+        key = self.ipex_rope(
+            key,
+            position_ids,
+            self.num_key_value_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
+        query = self.ipex_rope(
+            query,
+            position_ids,
+            self.num_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
+
+        if use_cache:
+            # This ipex op pre-allocates buffers for past_key_values and uses the beam index history
+            # to decide which beam should be used, making the attention scaled dot-product more efficient.
+            (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
+                query,
+                key,
+                value,
+                math.sqrt(self.head_dim),
+                past_key_value,
+                None,
+                attention_mask,
+            )
+        else:
+            value_states = value.transpose(1, 2)
+            query_states = query.transpose(1, 2)
+            key_states = key.transpose(1, 2)
+            kv_seq_len = key_states.shape[-2]
+
+            past_key_value = None
+            # repeat k/v heads if n_kv_heads < n_heads
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attention_mask is not None:
+                attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, seq_len, self.hidden_size)
+
+        if hasattr(self, "mha_linear_add"):
+            attn_output = self.mha_linear_add(attn_output, residual)
+        else:
+            attn_output = self.o_proj(attn_output)
+            attn_output = residual + attn_output
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, past_key_value, attn_weights
+
+
+class _IPEXLlamaMLP(nn.Module):
+    def __init__(self, module, config, distributed=False) -> None:
+        if is_ipex_version("<", "2.3.0"):
+            raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
         from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
 
         super().__init__()
@@ -231,15 +282,47 @@ def __init__(self, module, config, distributed=False):
             if k.startswith("__") or k.startswith("forward"):
                 continue
             setattr(self.__class__, k, getattr(module.__class__, k))
+        self.config = config
         self.distributed = distributed
         if not self.distributed:
-            self.mha_linear_add = LinearAdd(module.self_attn.o_proj)
-            self.mlp_linear_add = LinearAdd(module.mlp.down_proj)
-            del self.__dict__["_modules"]["self_attn"].o_proj
-            del self.__dict__["_modules"]["mlp"].down_proj
-            self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
-            del self.__dict__["_modules"]["mlp"].gate_proj
-            del self.__dict__["_modules"]["mlp"].up_proj
+            self.mlp_linear_add = LinearAdd(module.down_proj)
+            del self.__dict__["_modules"]["down_proj"]
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+            del self.__dict__["_modules"]["gate_proj"]
+            del self.__dict__["_modules"]["up_proj"]
+
+    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+        if hasattr(self, "linear_silu_mul"):
+            mlp_gate = self.linear_silu_mul(hidden_states)
+            if hasattr(self, "mlp_linear_add"):
+                hidden_states = self.mlp_linear_add(mlp_gate, residual)
+            else:
+                hidden_states = self.down_proj(mlp_gate)
+                hidden_states = residual + hidden_states
+        else:
+            hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+            hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
+class _IPEXLlamaDecoderLayerRef(nn.Module):
+    def __init__(self, module, config, distributed=False):
+        super().__init__()
+        for k, v in module.__dict__.items():
+            setattr(self, k, v)
+        for k, v in module.__class__.__dict__.items():
+            if k.startswith("__") or k.startswith("forward"):
+                continue
+            setattr(self.__class__, k, getattr(module.__class__, k))
+        self.distributed = distributed
+        self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed)
+        self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
 
     def forward(
         self,
@@ -270,34 +353,22 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, present_key_value, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=None,
+            residual=residual,
+            **kwargs,
         )
-        if hasattr(self, "mha_linear_add"):
-            hidden_states = self.mha_linear_add(hidden_states, residual)
-        else:
-            hidden_states = self.self_attn.o_proj(hidden_states)
-            hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
-        if hasattr(self, "linear_silu_mul"):
-            mlp_gate = self.linear_silu_mul(hidden_states)
-            if hasattr(self, "mlp_linear_add"):
-                hidden_states = self.mlp_linear_add(mlp_gate, residual)
-            else:
-                hidden_states = self.mlp.down_proj(mlp_gate)
-                hidden_states = residual + hidden_states
-        else:
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = residual + hidden_states
+        hidden_states = self.mlp(hidden_states, residual, **kwargs)
 
         outputs = (hidden_states,)
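
A side note on the wrapping pattern shared by _IPEXLlamaAttentionRef, _IPEXLlamaMLP and _IPEXLlamaDecoderLayerRef above: each wrapper copies the wrapped module's __dict__ (which carries the _parameters/_modules registries, so the original weights are reused), lifts the non-dunder class attributes onto the wrapper class, and then supplies its own forward that can fuse the residual add. The toy sketch below is not from this commit and all names in it are made up; it only illustrates the same pattern without any ipex dependency.

import torch
from torch import nn


class ToyBlock(nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.scale = 2.0

    def forward(self, x):
        return self.proj(x) * self.scale


class ToyBlockRef(nn.Module):
    def __init__(self, module):
        super().__init__()
        # Take over the original module's state: parameter/submodule registries and plain attributes.
        for k, v in module.__dict__.items():
            setattr(self, k, v)
        # Lift attributes defined on the original class, skipping dunders and forward.
        for k in module.__class__.__dict__:
            if k.startswith("__") or k.startswith("forward"):
                continue
            setattr(self.__class__, k, getattr(module.__class__, k))

    def forward(self, x, residual=None):
        # Replacement forward that folds in a residual add, mirroring what LinearAdd fuses in the diff.
        out = self.proj(x) * self.scale
        return out if residual is None else out + residual


block = ToyBlock()
wrapped = ToyBlockRef(block)
x = torch.randn(1, 4)
torch.testing.assert_close(wrapped(x), block(x))  # same weights, same output when no residual is passed

The same idea explains the changed call in _IPEXLlamaDecoderLayerRef.forward above: the residual is handed to self.self_attn and self.mlp, which either fuse it via LinearAdd or fall back to o_proj / down_proj plus an explicit add, and the attention wrapper now returns (attn_output, past_key_value, attn_weights) in that order.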
