Commit 4c3335b

fix comments
1 parent e03259c commit 4c3335b

4 files changed: +34 -48 lines

optimum/exporters/ipex/model_patcher.py (+6 -6)

@@ -24,9 +24,9 @@

 from .modeling_utils import (
     _IPEXLlamaDecoderLayerRef,
-    llama_attn_forward,
-    llama_layer_norm_forward,
-    llama_model_forward,
+    _llama_attn_forward,
+    _llama_layer_norm_forward,
+    _llama_model_forward,
 )


@@ -77,9 +77,9 @@ def _patch_llama_model(model):
     patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
     patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

-    convert_functions(model, LlamaModel, "forward", llama_model_forward)
-    convert_functions(model, LlamaAttention, "forward", llama_attn_forward)
-    convert_functions(model, LlamaRMSNorm, "forward", llama_layer_norm_forward)
+    convert_functions(model, LlamaModel, "forward", _llama_model_forward)
+    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
+    convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)

     convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
     return model
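The patching flow itself is unchanged; the helpers only gain a leading underscore to mark them as module-private. For readers unfamiliar with the mechanism, here is a minimal sketch of what a convert_functions-style helper typically does, assuming a simple walk-and-rebind strategy (an illustration, not the actual optimum.exporters.ipex implementation):

import types
import torch.nn as nn

def convert_functions_sketch(model: nn.Module, target_cls, name: str, new_fn):
    # Hypothetical stand-in for convert_functions: walk the module tree and
    # rebind `name` (e.g. "forward") on every instance of `target_cls` to the
    # patched function, so LlamaRMSNorm.forward would resolve to
    # _llama_layer_norm_forward after patching.
    for module in model.modules():
        if isinstance(module, target_cls):
            setattr(module, name, types.MethodType(new_fn, module))
    return model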

optimum/exporters/ipex/modeling_utils.py (+15 -25)

@@ -24,11 +24,13 @@
 from optimum.intel.utils.import_utils import is_ipex_version


-def llama_layer_norm_forward(self, hidden_states):
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
+def _llama_layer_norm_forward(self, hidden_states):
     return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)


-def llama_attn_forward(
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
+def _llama_attn_forward(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,

@@ -111,7 +113,8 @@ def llama_attn_forward(
     return attn_output, attn_weights, past_key_value


-def llama_model_forward(
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
+def _llama_model_forward(
     self,
     input_ids: torch.LongTensor = None,
     attention_mask: Optional[torch.Tensor] = None,

@@ -168,9 +171,6 @@ def llama_model_forward(
     # embed positions
     hidden_states = inputs_embeds

-    if self.gradient_checkpointing and self.training:
-        use_cache = False
-
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None

@@ -182,25 +182,14 @@

         past_key_value = past_key_values[idx] if past_key_values is not None else None

-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                use_cache,
-            )
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-            )
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )

         hidden_states = layer_outputs[0]

@@ -227,6 +216,7 @@ def llama_model_forward(
     )


+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
 class _IPEXLlamaDecoderLayerRef(nn.Module):
     def __init__(self, module, config, distributed=False):
         if is_ipex_version("<=", "2.3.0"):
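Behaviour is unchanged by the renames: `_llama_layer_norm_forward` still routes RMSNorm through the fused `torch.ops.torch_ipex.rmsnorm` kernel, and the gradient-checkpointing branches are dropped because the patched forward only serves inference. For reference, the eager computation that the fused kernel stands in for looks roughly like the sketch below (adapted from transformers' LlamaRMSNorm; the standalone function name is ours):

import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Eager RMSNorm as implemented by transformers' LlamaRMSNorm;
    # torch.ops.torch_ipex.rmsnorm fuses these steps into a single kernel.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)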

optimum/intel/ipex/modeling_base.py (+1)

@@ -62,6 +62,7 @@
 def _is_patched_with_ipex(model, task):
     if is_ipex_version("<=", "2.3.0"):
         return False
+
     if isinstance(model, torch.jit.ScriptModule):
         for node in model.graph.nodes():
             # Jit will record the codes position so we can check if the node use ipex exporter.
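The change here is only a blank line after the early return, but the surrounding code shows the detection strategy: `_is_patched_with_ipex` inspects the TorchScript graph of an exported model. As a rough, hypothetical illustration of that kind of scan (the `torch_ipex::` namespace test is an assumption, not the function's actual condition):

import torch

def looks_ipex_patched(scripted: torch.jit.ScriptModule) -> bool:
    # Hypothetical check: look for custom ops registered by
    # intel_extension_for_pytorch (node kinds such as "torch_ipex::rmsnorm")
    # in the traced graph of the scripted model.
    return any(node.kind().startswith("torch_ipex::") for node in scripted.graph.nodes())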

tests/ipex/test_modeling.py (+12 -17)

@@ -264,31 +264,26 @@ def test_pipeline(self, model_arch):
             {
                 "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
                 "use_cache": [True, False],
-                "num_beams": [1, 4],
-                "batch_size": [1, 4],
             }
         )
     )
     @unittest.skipIf(is_ipex_version("<=", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
-    def test_ipex_patching(self, test_name, model_arch, use_cache, num_beams, batch_size):
+    def test_ipex_patching_generation(self, test_name, model_arch, use_cache):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
-        model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+        model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
-        texts = ["This is a sample"] * batch_size
-        tokens = tokenizer(texts, padding=True, return_tensors="pt")
-        generation_config = GenerationConfig(
-            max_new_tokens=16, num_beams=num_beams, do_sample=False, use_cache=use_cache
-        )
-        outputs = model.generate(**tokens, generation_config=generation_config)
-        with torch.no_grad():
-            transformers_outputs = transformers_model(**tokens)
-
-        self.assertIsInstance(outputs.logits, torch.Tensor)
-        # Compare tensor outputs
-        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
+        # Test with batch_size is 1 and 2.
+        texts = ["This is a sample", ["This is the first input", "This is the second input"]]
+        for text in texts:
+            tokens = tokenizer(text, padding=True, return_tensors="pt")
+            for num_beams in [1, 4]:
+                generation_config = GenerationConfig(
+                    max_new_tokens=4, num_beams=num_beams, do_sample=True, top_p=0.9, top_k=5
+                )
+                outputs = model.generate(**tokens, generation_config=generation_config)
+                self.assertIsInstance(outputs, torch.Tensor)

     def test_compare_with_and_without_past_key_values(self):
         model_id = "echarlaix/tiny-random-gpt2-torchscript"
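The reworked test now checks generation end to end (batch sizes 1 and 2, beam sizes 1 and 4, sampling, with and without cache) instead of comparing logits against the transformers baseline. Outside the test suite, the same code path can be exercised roughly as follows; the checkpoint name is a placeholder, and any architecture supported by the IPEX patcher should work:

from transformers import AutoTokenizer, GenerationConfig
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Batched input with beam search and sampling, mirroring what the new test covers.
tokens = tokenizer(["This is the first input", "This is the second input"], padding=True, return_tensors="pt")
generation_config = GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True, top_p=0.9, top_k=5)
outputs = model.generate(**tokens, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))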
