@@ -29,90 +29,6 @@ def _llama_layer_norm_forward(self, hidden_states):
     return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
-def _llama_attn_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    query = self.q_proj(hidden_states)
-    key = self.k_proj(hidden_states)
-    value = self.v_proj(hidden_states)
-
-    kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len
-
-    query = query.view(bsz, q_len, self.num_heads, self.head_dim)
-    key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    # Use the ipex op to apply rotary position embedding more efficiently.
-    key = self.ipex_rope(
-        key,
-        position_ids,
-        self.num_key_value_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
-    query = self.ipex_rope(
-        query,
-        position_ids,
-        self.num_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
-
-    if use_cache:
-        # This ipex op pre-allocates buffers for past_key_values and uses the beam index history
-        # to decide which beam should be used, making the scaled dot-product attention more efficient.
-        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
-            query,
-            key,
-            value,
-            math.sqrt(self.head_dim),
-            past_key_value,
-            None,
-            attention_mask,
-        )
-    else:
-        value_states = value.transpose(1, 2)
-        query_states = query.transpose(1, 2)
-        key_states = key.transpose(1, 2)
-        kv_seq_len = key_states.shape[-2]
-
-        past_key_value = None
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
-
-
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
 def _llama_model_forward(
     self,
@@ -216,12 +132,147 @@ def _llama_model_forward(
     )
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
-class _IPEXLlamaDecoderLayerRef(nn.Module):
-    def __init__(self, module, config, distributed=False):
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
+class _IPEXLlamaAttentionRef(nn.Module):
+    def __init__(self, module, config, distributed=False) -> None:
         if is_ipex_version("<", "2.3.0"):
-            raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
+            raise ImportError(
+                "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding"
+            )
+        from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
+
+        super().__init__()
+        for k, v in module.__dict__.items():
+            setattr(self, k, v)
+        for k, v in module.__class__.__dict__.items():
+            if k.startswith("__") or k.startswith("forward"):
+                continue
+            setattr(self.__class__, k, getattr(module.__class__, k))
+        self.config = config
+        self.distributed = distributed
+        if not self.distributed:
+            self.mha_linear_add = LinearAdd(self.o_proj)
+            del self.__dict__["_modules"]["o_proj"]
+        self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
+            text_max_length=module.config.max_position_embeddings
+        )
+        self.ipex_rope = RotaryEmbedding(
+            module.config.max_position_embeddings,
+            module.config.hidden_size // module.config.num_attention_heads,
+            module.config.rope_theta,
+            module.config.architectures[0],
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        residual: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            residual (`torch.Tensor`): residual tensor to the layer of shape `
+        """
+        bsz, seq_len, _ = hidden_states.size()
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
 
+        kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len
+
+        query = query.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
+        key = key.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
+        value = value.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
+        # Use the ipex op to apply rotary position embedding more efficiently.
+        key = self.ipex_rope(
+            key,
+            position_ids,
+            self.num_key_value_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
+        query = self.ipex_rope(
+            query,
+            position_ids,
+            self.num_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
+
+        if use_cache:
+            # This ipex op pre-allocates buffers for past_key_values and uses the beam index history
+            # to decide which beam should be used, making the scaled dot-product attention more efficient.
+            (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
+                query,
+                key,
+                value,
+                math.sqrt(self.head_dim),
+                past_key_value,
+                None,
+                attention_mask,
+            )
+        else:
+            value_states = value.transpose(1, 2)
+            query_states = query.transpose(1, 2)
+            key_states = key.transpose(1, 2)
+            kv_seq_len = key_states.shape[-2]
+
+            past_key_value = None
+            # repeat k/v heads if n_kv_heads < n_heads
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attention_mask is not None:
+                attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = attn_output.reshape(bsz, seq_len, self.hidden_size)
+
+        if hasattr(self, "mha_linear_add"):
+            attn_output = self.mha_linear_add(attn_output, residual)
+        else:
+            attn_output = self.o_proj(attn_output)
+            attn_output = residual + attn_output
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, past_key_value, attn_weights
+
+
+class _IPEXLlamaMLP(nn.Module):
+    def __init__(self, module, config, distributed=False) -> None:
+        if is_ipex_version("<", "2.3.0"):
+            raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
         from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
 
         super().__init__()
@@ -231,15 +282,47 @@ def __init__(self, module, config, distributed=False):
             if k.startswith("__") or k.startswith("forward"):
                 continue
             setattr(self.__class__, k, getattr(module.__class__, k))
+        self.config = config
         self.distributed = distributed
         if not self.distributed:
-            self.mha_linear_add = LinearAdd(module.self_attn.o_proj)
-            self.mlp_linear_add = LinearAdd(module.mlp.down_proj)
-            del self.__dict__["_modules"]["self_attn"].o_proj
-            del self.__dict__["_modules"]["mlp"].down_proj
-            self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
-            del self.__dict__["_modules"]["mlp"].gate_proj
-            del self.__dict__["_modules"]["mlp"].up_proj
+            self.mlp_linear_add = LinearAdd(module.down_proj)
+            del self.__dict__["_modules"]["down_proj"]
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+            del self.__dict__["_modules"]["gate_proj"]
+            del self.__dict__["_modules"]["up_proj"]
+
+    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+        """
+        if hasattr(self, "linear_silu_mul"):
+            mlp_gate = self.linear_silu_mul(hidden_states)
+            if hasattr(self, "mlp_linear_add"):
+                hidden_states = self.mlp_linear_add(mlp_gate, residual)
+            else:
+                hidden_states = self.down_proj(mlp_gate)
+                hidden_states = residual + hidden_states
+        else:
+            hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+            hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
+class _IPEXLlamaDecoderLayerRef(nn.Module):
+    def __init__(self, module, config, distributed=False):
+        super().__init__()
+        for k, v in module.__dict__.items():
+            setattr(self, k, v)
+        for k, v in module.__class__.__dict__.items():
+            if k.startswith("__") or k.startswith("forward"):
+                continue
+            setattr(self.__class__, k, getattr(module.__class__, k))
+        self.distributed = distributed
+        self.self_attn = _IPEXLlamaAttentionRef(module.self_attn, config, distributed)
+        self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
 
     def forward(
         self,
@@ -270,34 +353,22 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, present_key_value, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=None,
+            residual=residual,
+            **kwargs,
         )
-        if hasattr(self, "mha_linear_add"):
-            hidden_states = self.mha_linear_add(hidden_states, residual)
-        else:
-            hidden_states = self.self_attn.o_proj(hidden_states)
-            hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
-        if hasattr(self, "linear_silu_mul"):
-            mlp_gate = self.linear_silu_mul(hidden_states)
-            if hasattr(self, "mlp_linear_add"):
-                hidden_states = self.mlp_linear_add(mlp_gate, residual)
-            else:
-                hidden_states = self.mlp.down_proj(mlp_gate)
-                hidden_states = residual + hidden_states
-        else:
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = residual + hidden_states
+        hidden_states = self.mlp(hidden_states, residual, **kwargs)
 
         outputs = (hidden_states,)
 
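For context, the classes added in this diff (`_IPEXLlamaAttentionRef`, `_IPEXLlamaMLP`, `_IPEXLlamaDecoderLayerRef`) wrap existing Llama submodules and are meant to be swapped into an already loaded model. Below is a minimal usage sketch, not part of this commit: the helper name `replace_llama_decoder_layers`, the checkpoint name, and the use of `AutoModelForCausalLM` are assumptions for illustration; the repository's actual patching entry point may differ.

# Hypothetical sketch: replace each Llama decoder layer with the IPEX reference
# layer defined in this diff. Assumes _IPEXLlamaDecoderLayerRef is importable
# from the patched module.
import torch
from transformers import AutoModelForCausalLM


def replace_llama_decoder_layers(model, distributed=False):
    # Each wrapper copies the original layer's attributes and builds
    # _IPEXLlamaAttentionRef / _IPEXLlamaMLP around its submodules.
    for idx, layer in enumerate(model.model.layers):
        model.model.layers[idx] = _IPEXLlamaDecoderLayerRef(layer, model.config, distributed)
    return model


model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)
model = replace_llama_decoder_layers(model.eval())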