Commit a76be08

Enable Text2text task on ipex (#1054)
* enable IPEXModelForSeq2SeqLM
* set static cache
* add tests for IPEXModelForSeq2SeqLM
* add docs
* fix readme
* refactor compile
* fix check
* fix ruff check
* fix check
* fix tests
* fix opt tests

Signed-off-by: jiqing-feng &lt;jiqing.feng@intel.com&gt;
1 parent 3c229fc commit a76be08

File tree: 10 files changed (+330 / -31 lines)

docs/source/ipex/inference.mdx

+2 -1

@@ -14,7 +14,7 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m

 ## Loading

-You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
+You can load your model and apply IPEX optimizations (apply torch.compile except text-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators.
 For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22.

@@ -43,3 +43,4 @@ As shown in the table below, each task is associated with a class enabling to au
 | `IPEXModelForMaskedLM` | `fill-mask` |
 | `IPEXModelForAudioClassification` | `audio-classification` |
 | `IPEXModelForCausalLM` | `text-generation` |
+| `IPEXModelForSeq2SeqLM` | `text2text-generation` |
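As a quick check of the new table entry, the class is used like the other IPEX task classes; a minimal sketch (the `t5-small` checkpoint and the prompt are illustrative, not part of this commit):

```python
# Minimal sketch: load an encoder-decoder model through the new IPEX class.
# The checkpoint and prompt are assumptions; any seq2seq model on the Hub works the same way.
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSeq2SeqLM

model_id = "t5-small"
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("translate English to French: Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```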

docs/source/ipex/models.mdx

+1

@@ -40,6 +40,7 @@ Here is the list of the supported architectures :
 - Roberta
 - Roformer
 - SqueezeBert
+- T5
 - UniSpeech
 - Vit
 - Wav2Vec2

optimum/intel/__init__.py

+2

@@ -54,6 +54,7 @@
     _import_structure["utils.dummy_ipex_objects"] = []
     _import_structure["ipex"] = [
         "IPEXModelForCausalLM",
+        "IPEXModelForSeq2SeqLM",
         "IPEXModelForSequenceClassification",
         "IPEXModelForMaskedLM",
         "IPEXModelForTokenClassification",
@@ -248,6 +249,7 @@
         IPEXModelForImageClassification,
         IPEXModelForMaskedLM,
         IPEXModelForQuestionAnswering,
+        IPEXModelForSeq2SeqLM,
         IPEXModelForSequenceClassification,
         IPEXModelForTokenClassification,
     )

optimum/intel/ipex/__init__.py

+1

@@ -20,6 +20,7 @@
     IPEXModelForImageClassification,
     IPEXModelForMaskedLM,
     IPEXModelForQuestionAnswering,
+    IPEXModelForSeq2SeqLM,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
 )

optimum/intel/ipex/modeling_base.py

+123 -29

@@ -30,6 +30,7 @@
     AutoModelForImageClassification,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
     GenerationConfig,
@@ -60,8 +61,8 @@
 _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
-# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6
-_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit")
+# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6
+_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")


 def _is_patched_with_ipex(model, task, use_cache: bool = True):
@@ -84,15 +85,21 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        warmup: Optional[bool] = True,
         **kwargs,
     ):
         config = config or model.config
         OptimizedModel.__init__(self, model=model, config=config)

+        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
+        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
+        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
+        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
         self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
         self.use_cache = kwargs.get("use_cache", False)
         self.model_save_dir = model_save_dir
         self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
+        self.compiled = False

         self.input_names = set(inspect.signature(model.forward).parameters)

@@ -104,25 +111,10 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)

-        # Non-generation tasks can use torch.compile to get acceleration.
-        if (
-            model.device.type == "cpu"
-            and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS
-            and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
-            and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
-        ):
-            from torch._inductor import config
-
-            # System level optimization
-            torch._inductor.config.cpp_wrapper = True
-            os.environ["TORCHINDUCTOR_FREEZING"] = "1"
-            logger.info("Enable torch.compile optimization, start warm up")
-            self.model.forward = torch.compile(self.model.forward)
-            inputs = prepare_jit_inputs(model, self.export_feature, False)
-            with torch.no_grad():
-                self.model(**inputs)
-                self.model(**inputs)
-            logger.info("Warm up end")
+        self.maybe_apply_torch_compile()
+
+        if warmup:
+            self._init_warmup()

     @classmethod
     def _from_transformers(cls, *args, **kwargs):
@@ -192,6 +184,31 @@ def to(self, device: Union[torch.device, str]):
     def can_generate(self):
         return isinstance(self, GenerationMixin)

+    def maybe_apply_torch_compile(self):
+        if (
+            self.model.device.type != "cpu"
+            or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
+            or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
+        ):
+            return
+        if self.use_cache and not self._supports_static_cache:
+            return
+        from torch._inductor import config as inductor_config
+
+        # System level optimization
+        inductor_config.cpp_wrapper = True
+        os.environ["TORCHINDUCTOR_FREEZING"] = "1"
+        logger.info("Enable torch.compile optimization")
+        self.model.forward = torch.compile(self.model.forward)
+        self.compiled = True
+
+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        with torch.no_grad():
+            self.model(**inputs)
+            self.model(**inputs)
+        logger.info("Warm up end")
+

 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
@@ -236,16 +253,10 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: Optional[bool] = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)
-
-        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
-        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
-        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
-        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
-        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
-
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
         if self._add_patch:
             self._supports_cache_class = True
         GenerationMixin.__init__(self)
@@ -269,6 +280,9 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

+        if warmup:
+            self._init_warmup()
+
     @torch.no_grad()
     def forward(
         self,
@@ -285,6 +299,9 @@ def _prepare_generation_config(
     ) -> Tuple[GenerationConfig, Dict]:
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         generation_method = generation_config.get_generation_mode().value
+        if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
+            # Use static cache for torch compile
+            generation_config.cache_implementation = "static"
         if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
             raise ValueError(
                 f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
@@ -337,6 +354,83 @@ def generate(self, *args, **kwargs):

         return result

+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
+
+
+class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
+    auto_model_class = AutoModelForSeq2SeqLM
+    export_feature = "text2text-generation"
+
+    def __init__(
+        self,
+        model,
+        config: PretrainedConfig = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        use_cache: bool = True,
+        warmup: Optional[bool] = True,
+        **kwargs,
+    ):
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
+        GenerationMixin.__init__(self)
+
+        model_type = self.config.model_type.replace("_", "-")
+        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config)
+
+        self.config.is_decoder = False
+        self.config.is_encoder_decoder = True
+
+        self.generation_config = GenerationConfig.from_model_config(self.config)
+        try:
+            self.model_cls = get_class_from_dynamic_module(
+                self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
+            )
+        except AttributeError:
+            self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)
+
+        if hasattr(self.model_cls, "_convert_to_standard_cache"):
+            self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
+
+        if warmup:
+            self._init_warmup()
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
+    def _prepare_generation_config(
+        self, generation_config: Optional[GenerationConfig], **kwargs: Dict
+    ) -> Tuple[GenerationConfig, Dict]:
+        generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
+        # Use static cache for torch.compile
+        if self.compiled:
+            generation_config.cache_implementation = "static"
+
+        return generation_config, model_kwargs
+
+    def _reorder_cache(self, *args, **kwargs):
+        return self.model._reorder_cache(*args, **kwargs)
+
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        return self.model.prepare_inputs_for_generation(*args, **kwargs)
+
+    def get_encoder(self, *args, **kwargs):
+        return self.model.get_encoder(*args, **kwargs)
+
+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
+

 def _ipex_crop_past_key_values(model, past_key_values, max_length):
     if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):

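A hedged sketch of how the new `compiled` flag and the static-cache switch are meant to interact at generation time; the checkpoint is an assumption, and whether `torch.compile` is actually applied depends on the CPU / IPEX-version / static-cache checks in `maybe_apply_torch_compile` above:

```python
# Sketch only: illustrates the gating added above, not a guaranteed outcome.
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSeq2SeqLM

model_id = "google/flan-t5-small"  # assumed checkpoint
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# `compiled` is set by maybe_apply_torch_compile(): it stays False on non-CPU devices,
# for model types listed in _COMPILE_NOT_READY_MODEL_TYPES, on IPEX < 2.5.0, or when
# use_cache=True but the model cannot use a static cache.
print("torch.compile applied:", model.compiled)

# When compiled is True, _prepare_generation_config() forces
# generation_config.cache_implementation = "static" so the compiled graph
# sees fixed-shape cache tensors.
inputs = tokenizer("summarize: IPEX now supports text2text generation.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```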
optimum/intel/ipex/utils.py

+1

@@ -16,6 +16,7 @@
 _HEAD_TO_AUTOMODELS = {
     "feature-extraction": "IPEXModel",
     "text-generation": "IPEXModelForCausalLM",
+    "text2text-generation": "IPEXModelForSeq2SeqLM",
     "text-classification": "IPEXModelForSequenceClassification",
     "token-classification": "IPEXModelForTokenClassification",
     "question-answering": "IPEXModelForQuestionAnswering",

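`_HEAD_TO_AUTOMODELS` maps a task string to the class name exported by `optimum.intel`; a minimal sketch of that lookup (the `getattr` call shown here is an illustration, not necessarily the library's exact call site):

```python
# Sketch: resolve the class registered for the new task string.
import optimum.intel
from optimum.intel.ipex.utils import _HEAD_TO_AUTOMODELS

cls_name = _HEAD_TO_AUTOMODELS["text2text-generation"]  # "IPEXModelForSeq2SeqLM"
model_class = getattr(optimum.intel, cls_name)
print(model_class.__name__)
```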
optimum/intel/pipelines/pipeline_base.py

+19

@@ -58,6 +58,7 @@
         IPEXModelForImageClassification,
         IPEXModelForMaskedLM,
         IPEXModelForQuestionAnswering,
+        IPEXModelForSeq2SeqLM,
         IPEXModelForSequenceClassification,
         IPEXModelForTokenClassification,
     )
@@ -69,6 +70,24 @@
         "default": "gpt2",
         "type": "text",
     },
+    "summarization": {
+        "impl": SummarizationPipeline,
+        "class": (IPEXModelForSeq2SeqLM,),
+        "default": "t5-base",
+        "type": "text",
+    },
+    "translation": {
+        "impl": TranslationPipeline,
+        "class": (IPEXModelForSeq2SeqLM,),
+        "default": "t5-small",
+        "type": "text",
+    },
+    "text2text-generation": {
+        "impl": Text2TextGenerationPipeline,
+        "class": (IPEXModelForSeq2SeqLM,),
+        "default": "t5-small",
+        "type": "text",
+    },
     "fill-mask": {
         "impl": FillMaskPipeline,
         "class": (IPEXModelForMaskedLM,),

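With these registrations, the Optimum Intel `pipeline` helper should be able to route the three new tasks to `IPEXModelForSeq2SeqLM`; a small sketch (task and checkpoint choices are illustrative):

```python
# Sketch: run a seq2seq task through the IPEX-accelerated pipeline helper.
# accelerator="ipex" selects the IPEX model classes registered above.
from optimum.intel.pipelines import pipeline

generator = pipeline("text2text-generation", model="t5-small", accelerator="ipex")
print(generator("translate English to German: The table is round."))
```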
optimum/intel/utils/dummy_ipex_objects.py

+11

@@ -70,6 +70,17 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["ipex"])


+class IPEXModelForSeq2SeqLM(metaclass=DummyObject):
+    _backends = ["ipex"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["ipex"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["ipex"])
+
+
 class IPEXModelForQuestionAnswering(metaclass=DummyObject):
     _backends = ["ipex"]

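The dummy class keeps `IPEXModelForSeq2SeqLM` importable when the IPEX backend is missing; a sketch of the expected failure mode in that situation (the checkpoint is illustrative, and the exact error text may differ):

```python
# Sketch: in an environment without intel-extension-for-pytorch, the dummy object
# raises an ImportError naming the missing "ipex" backend instead of failing at import time.
from optimum.intel import IPEXModelForSeq2SeqLM

try:
    IPEXModelForSeq2SeqLM.from_pretrained("t5-small")
except ImportError as err:
    print(err)
```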