from optimum.intel.utils.import_utils import is_ipex_version


-def llama_layer_norm_forward(self, hidden_states):
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
+def _llama_layer_norm_forward(self, hidden_states):
     return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)

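For reference, a minimal eager-mode sketch of the computation that the fused torch.ops.torch_ipex.rmsnorm call replaces, following the upstream LlamaRMSNorm.forward; the name _rmsnorm_reference is illustrative and not part of this diff:

def _rmsnorm_reference(self, hidden_states):
    # Plain PyTorch RMSNorm: normalize by the root mean square in float32,
    # then scale by the learned weight and cast back to the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)
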
-def llama_attn_forward(
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
+def _llama_attn_forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -111,7 +113,8 @@ def llama_attn_forward(
     return attn_output, attn_weights, past_key_value

-def llama_model_forward(
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
+def _llama_model_forward(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
@@ -168,9 +171,6 @@ def llama_model_forward(
     # embed positions
     hidden_states = inputs_embeds

-    if self.gradient_checkpointing and self.training:
-        use_cache = False
-
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
@@ -182,25 +182,14 @@ def llama_model_forward(

         past_key_value = past_key_values[idx] if past_key_values is not None else None

-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                use_cache,
-            )
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-            )
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )

         hidden_states = layer_outputs[0]
@@ -227,6 +216,7 @@ def llama_model_forward(
     )

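For illustration only, a hedged sketch of how the patched forwards defined above could be bound onto an already-loaded transformers Llama model via types.MethodType; the checkpoint name is just an example, and the actual patching helper used by optimum-intel may differ:

import types
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
llama = model.model  # the inner LlamaModel
# Replace the model-level forward and every RMSNorm / attention forward
# with the IPEX-oriented variants defined in this file.
llama.forward = types.MethodType(_llama_model_forward, llama)
llama.norm.forward = types.MethodType(_llama_layer_norm_forward, llama.norm)
for layer in llama.layers:
    layer.self_attn.forward = types.MethodType(_llama_attn_forward, layer.self_attn)
    layer.input_layernorm.forward = types.MethodType(_llama_layer_norm_forward, layer.input_layernorm)
    layer.post_attention_layernorm.forward = types.MethodType(_llama_layer_norm_forward, layer.post_attention_layernorm)
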
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayerRef(nn.Module):
     def __init__(self, module, config, distributed=False):
         if is_ipex_version("<=", "2.3.0"):