Skip to content

Commit 62f570f

Browse files
authored
Fix compatibility for latest transformers release (#570)
* fix compatibility for latest transformers release * update setup * update setup * fix test input size * fix prepare generation for llama models
1 parent b3f9711 commit 62f570f

File tree

6 files changed

+126
-17
lines changed

6 files changed

+126
-17
lines changed

optimum/intel/ipex/modeling_base.py

+69-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from optimum.utils import NormalizedConfigManager
4747

4848
from ..generation.modeling import jit_trace, prepare_jit_inputs
49-
from ..utils.import_utils import is_torch_version
49+
from ..utils.import_utils import is_torch_version, is_transformers_version
5050
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
5151

5252

@@ -326,7 +326,8 @@ def __init__(
326326
# Perform the initial warmup at the end of __init__
327327
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
328328

329-
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
329+
model_type = config.model_type.replace("_", "-")
330+
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
330331
self.model_dtype = kwargs.get("model_dtype", self.dtype)
331332
self.use_cache = "past_key_values" in self.input_names
332333

@@ -339,6 +340,7 @@ def __init__(
339340
)
340341
config.is_decoder = True
341342
config.is_encoder_decoder = False
343+
342344
self.generation_config = GenerationConfig.from_model_config(config)
343345
try:
344346
self.model_cls = get_class_from_dynamic_module(
@@ -347,7 +349,12 @@ def __init__(
347349
except AttributeError:
348350
self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
349351
self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
350-
self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
352+
353+
if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
354+
self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama
355+
else:
356+
self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
357+
351358
if hasattr(self.model_cls, "_convert_to_standard_cache"):
352359
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
353360
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
@@ -430,3 +437,62 @@ def forward(
430437
past_key_values = outputs["past_key_values"] if self.use_cache else None
431438

432439
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
440+
441+
442+
def _prepare_inputs_for_generation_for_llama(
443+
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
444+
):
445+
from transformers.cache_utils import Cache
446+
447+
if past_key_values is not None:
448+
if isinstance(past_key_values, Cache):
449+
cache_length = past_key_values.get_seq_length()
450+
past_length = past_key_values.seen_tokens
451+
max_cache_length = past_key_values.get_max_length()
452+
else:
453+
cache_length = past_length = past_key_values[0][0].shape[2]
454+
max_cache_length = None
455+
456+
# Keep only the unprocessed tokens:
457+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
458+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
459+
# input)
460+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
461+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
462+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
463+
# input_ids based on the past_length.
464+
elif past_length < input_ids.shape[1]:
465+
input_ids = input_ids[:, past_length:]
466+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
467+
468+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
469+
if (
470+
max_cache_length is not None
471+
and attention_mask is not None
472+
and cache_length + input_ids.shape[1] > max_cache_length
473+
):
474+
attention_mask = attention_mask[:, -max_cache_length:]
475+
476+
position_ids = kwargs.get("position_ids", None)
477+
if attention_mask is not None and position_ids is None:
478+
# create position_ids on the fly for batch generation
479+
position_ids = attention_mask.long().cumsum(-1) - 1
480+
position_ids.masked_fill_(attention_mask == 0, 1)
481+
if past_key_values:
482+
position_ids = position_ids[:, -input_ids.shape[1] :]
483+
484+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
485+
if inputs_embeds is not None and past_key_values is None:
486+
model_inputs = {"inputs_embeds": inputs_embeds}
487+
else:
488+
model_inputs = {"input_ids": input_ids}
489+
490+
model_inputs.update(
491+
{
492+
"position_ids": position_ids,
493+
"past_key_values": past_key_values,
494+
"use_cache": kwargs.get("use_cache"),
495+
"attention_mask": attention_mask,
496+
}
497+
)
498+
return model_inputs

optimum/intel/neural_compressor/trainer.py

+39
Original file line numberDiff line numberDiff line change
@@ -941,3 +941,42 @@ def get_model_sparsity(self):
941941
if self._compression_manager is not None:
942942
sparsity = self._compression_manager.model.report_sparsity()[-1]
943943
return sparsity
944+
945+
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
946+
# TODO : can be removed once transformers >= v4.38.0
947+
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
948+
if is_torch_tpu_available():
949+
xm.mark_step()
950+
951+
logs: Dict[str, float] = {}
952+
953+
# all_gather + mean() to get average loss over all processes
954+
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
955+
956+
# reset tr_loss to zero
957+
tr_loss -= tr_loss
958+
959+
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
960+
logs["learning_rate"] = self._get_learning_rate()
961+
962+
self._total_loss_scalar += tr_loss_scalar
963+
self._globalstep_last_logged = self.state.global_step
964+
self.store_flos()
965+
966+
self.log(logs)
967+
968+
metrics = None
969+
if self.control.should_evaluate:
970+
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
971+
self._report_to_hp_search(trial, self.state.global_step, metrics)
972+
973+
# Run delayed LR scheduler now that metrics are populated
974+
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
975+
metric_to_check = self.args.metric_for_best_model
976+
if not metric_to_check.startswith("eval_"):
977+
metric_to_check = f"eval_{metric_to_check}"
978+
self.lr_scheduler.step(metrics[metric_to_check])
979+
980+
if self.control.should_save:
981+
self._save_checkpoint(model, trial, metrics=metrics)
982+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

setup.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
INSTALL_REQUIRE = [
1515
"torch>=1.11",
16-
"optimum>=1.17.0",
17-
"transformers>=4.29.0,<4.38.0",
16+
"optimum~=1.17",
17+
"transformers>=4.36.0,<4.39.0",
1818
"datasets>=1.4.0",
1919
"sentencepiece",
2020
"scipy",
@@ -43,14 +43,11 @@
4343
"neural-compressor>=2.2.0",
4444
"onnx",
4545
"onnxruntime<1.15.0",
46-
"transformers>=4.34.0",
4746
],
4847
"openvino": [
4948
"openvino>=2023.3",
5049
"onnx",
5150
"onnxruntime",
52-
"transformers>=4.36.0",
53-
"optimum>=1.16.1",
5451
],
5552
"openvino-tokenizers": ["openvino-tokenizers[transformers]"],
5653
"nncf": ["nncf>=2.8.1"],

tests/ipex/test_inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_text_generation_pipeline_inference(self, model_arch):
115115
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, return_dict=False)
116116
model = model.eval()
117117
tokenizer = AutoTokenizer.from_pretrained(model_id)
118-
inputs = "DeepSpeed is a machine learning framework for deep neural networks and deep reinforcement learning. It is written in C++ and is available for Linux, Mac OS X,"
118+
inputs = "This is a simple input"
119119
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
120120
with torch.inference_mode():
121121
output = text_generator(inputs)

tests/ipex/test_modeling.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
6868
"levit": "hf-internal-testing/tiny-random-LevitModel",
6969
"llama": "fxmarty/tiny-llama-fast-tokenizer",
70-
"opt": "hf-internal-testing/tiny-random-OPTModel",
7170
"marian": "sshleifer/tiny-marian-en-de",
7271
"mbart": "hf-internal-testing/tiny-random-mbart",
7372
"mistral": "echarlaix/tiny-random-mistral",
@@ -76,6 +75,8 @@
7675
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
7776
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
7877
"mt5": "stas/mt5-tiny-random",
78+
"opt": "hf-internal-testing/tiny-random-OPTModel",
79+
"phi": "hf-internal-testing/tiny-random-PhiForCausalLM",
7980
"resnet": "hf-internal-testing/tiny-random-resnet",
8081
"roberta": "hf-internal-testing/tiny-random-roberta",
8182
"roformer": "hf-internal-testing/tiny-random-roformer",
@@ -199,7 +200,7 @@ def test_pipeline(self, model_arch):
199200
class IPEXModelForCausalLMTest(unittest.TestCase):
200201
SUPPORTED_ARCHITECTURES = (
201202
"bart",
202-
# "gpt_bigcode",
203+
"gpt_bigcode",
203204
"blenderbot",
204205
"blenderbot-small",
205206
"bloom",
@@ -208,8 +209,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
208209
"gpt_neo",
209210
"gpt_neox",
210211
"llama",
211-
# "mistral",
212-
# "mpt",
212+
"mistral",
213+
# "phi",
214+
"mpt",
213215
"opt",
214216
)
215217
GENERATION_LENGTH = 100

tests/openvino/test_modeling.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
483483
"gpt_neo",
484484
"gpt_neox",
485485
"llama",
486-
"llama_gptq",
486+
# "llama_gptq",
487487
"marian",
488488
"mistral",
489489
"mpt",
@@ -504,7 +504,7 @@ def test_compare_to_transformers(self, model_arch):
504504
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
505505
self.assertIsInstance(ov_model.config, PretrainedConfig)
506506
self.assertTrue(ov_model.use_cache)
507-
self.assertEqual(ov_model.stateful, self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode")
507+
508508
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
509509
tokenizer = AutoTokenizer.from_pretrained(model_id)
510510
tokens = tokenizer(
@@ -520,10 +520,15 @@ def test_compare_to_transformers(self, model_arch):
520520
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
521521
self.assertTrue("past_key_values" in ov_outputs)
522522
self.assertIsInstance(ov_outputs.past_key_values, tuple)
523-
if self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode":
523+
524+
is_stateful = ov_model.config.model_type not in {"gpt_bigcode", "llama"} and self.IS_SUPPORT_STATEFUL
525+
self.assertEqual(ov_model.stateful, is_stateful)
526+
if is_stateful:
524527
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
528+
525529
with torch.no_grad():
526530
transformers_outputs = transformers_model(**tokens)
531+
527532
# Compare tensor outputs
528533
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
529534
del transformers_model
@@ -540,7 +545,7 @@ def test_pipeline(self, model_arch):
540545
model.half()
541546
model.compile()
542547
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
543-
outputs = pipe("This is a sample", max_length=10)
548+
outputs = pipe("This is a sample", max_length=20)
544549
self.assertEqual(pipe.device, model.device)
545550
self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
546551
del pipe

0 commit comments

Comments
 (0)