From 47511644b8dee09e4294c60a9257242a552588b6 Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Tue, 30 Jan 2024 09:07:04 -0800
Subject: [PATCH 1/6] Handle autocast in IPEXModel.forward

---
 optimum/intel/ipex/modeling_base.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index a522fef265..10e4a78058 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -67,6 +67,7 @@ def __init__(
         OptimizedModel.__init__(self, model=model, config=config)
         # To do: add XPU support
         self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self._dtype = self.config.torch_dtype
         self.model.to(self._device)
         self.model_save_dir = model_save_dir

@@ -190,7 +191,7 @@ def forward(
         if "token_type_ids" in self.input_names:
             inputs["token_type_ids"] = token_type_ids

-        outputs = self.model(**inputs)
+        outputs = self._call_model(**inputs)
         return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

     def eval(self):
@@ -201,6 +202,10 @@ def eval(self):
     def device(self) -> torch.device:
         return self._device

+    @property
+    def dtype(self) -> torch.dtype:
+        return self._dtype
+
     def to(self, device: Union[torch.device, str]):
         self._device = device if isinstance(device, torch.device) else torch.device(device)
         self.model.to(self._device)
@@ -209,6 +214,14 @@ def to(self, device: Union[torch.device, str]):
     def can_generate(self):
         return isinstance(self, GenerationMixin)

+    def _call_model(self, *args, **kwargs):
+        try:
+            with torch.autocast(self.device.type, self.dtype):
+                out = self.model(*args, **kwargs)
+        except RuntimeError:
+            out = self.model(*args, **kwargs)
+        return out
+

 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
@@ -238,7 +251,7 @@ def forward(
             "pixel_values": pixel_values,
         }

-        outputs = self.model(**inputs)
+        outputs = self._call_model(**inputs)
         return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


@@ -259,7 +272,7 @@ def forward(
         if "attention_mask" in self.input_names:
             inputs["attention_mask"] = attention_mask

-        outputs = self.model(**inputs)
+        outputs = self._call_model(**inputs)
         return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


@@ -268,7 +281,7 @@ class IPEXModelForQuestionAnswering(IPEXModel):
     export_feature = "question-answering"

     def forward(self, *args, **kwargs):
-        outputs = self.model(*args, **kwargs)
+        outputs = self._call_model(*args, **kwargs)
         start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
         end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
         return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -289,7 +302,7 @@ def __init__(
         super().__init__(model, config, model_save_dir=model_save_dir)
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

-        self.model_dtype = kwargs.get("model_dtype", None)
+        self.model_dtype = kwargs.get("model_dtype", self.dtype)
         self.use_cache = "past_key_values" in self.input_names

         if use_cache ^ self.use_cache:
@@ -367,7 +380,7 @@ def forward(
             inputs["past_key_values"] = past_key_values

         # 2. Model forward
-        outputs = self.model(**inputs)
+        outputs = self._call_model(**inputs)

         # 3. Process model outputs
         if isinstance(outputs, (list, tuple)):
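The change at the heart of patch 1 is the `_call_model` wrapper: every forward now runs under `torch.autocast` for the model's dtype, with a plain call as the fallback when autocast raises a `RuntimeError` (e.g. an unsupported device/dtype combination). A minimal self-contained sketch of the same pattern; `DummyModel` is an illustrative stand-in for the traced IPEX module, not part of the patch:

    import torch

    # Stand-in for the jit-traced model held in IPEXModel.model
    class DummyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.linear(x)

    model = DummyModel()
    x = torch.randn(1, 4)
    try:
        # Mirror of _call_model: run the forward under autocast for the target dtype...
        with torch.autocast("cpu", torch.bfloat16):
            out = model(x)
    except RuntimeError:
        # ...and fall back to a plain forward if autocast is unsupported
        out = model(x)
    print(out.dtype)  # torch.bfloat16 when CPU autocast succeeds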
From 0edb5c41700780711ecec70427d655f07c5bae4c Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Tue, 30 Jan 2024 11:00:17 -0800
Subject: [PATCH 2/6] Handle missing torch_dtype in config

---
 optimum/intel/ipex/modeling_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 10e4a78058..49b8a986a1 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -67,7 +67,7 @@ def __init__(
         OptimizedModel.__init__(self, model=model, config=config)
         # To do: add XPU support
         self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self._dtype = self.config.torch_dtype
+        self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
         self.model.to(self._device)
         self.model_save_dir = model_save_dir

From 742ff39baafa6d1d14126dd11a64ca6b10309cc9 Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Tue, 30 Jan 2024 18:51:22 -0800
Subject: [PATCH 3/6] Warmup IPEX models at init

---
 optimum/intel/ipex/modeling_base.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 49b8a986a1..b61e984d0a 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -15,6 +15,7 @@

 import logging
 import os
+from functools import wraps
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional, Tuple, Union
@@ -43,7 +44,7 @@
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager

-from ..generation.modeling import jit_trace
+from ..generation.modeling import jit_trace, prepare_jit_inputs
 from ..utils.import_utils import is_torch_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask

@@ -62,6 +63,7 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        initial_warmup: bool = True,
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
@@ -79,6 +81,8 @@ def __init__(
         AutoConfig.register(self.base_model_prefix, AutoConfig)
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
+        if initial_warmup:
+            self._init_warmup()

     @classmethod
     def _from_transformers(
@@ -222,6 +226,14 @@ def _call_model(self, *args, **kwargs):
             out = self.model(*args, **kwargs)
         return out

+    def _init_warmup(self):
+        # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
+        # the results of the compute are unpredictable
+        use_cache = getattr(self, "use_cache", getattr(self.config, "use_cache", False))
+        dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+        for _ in range(2):
+            self(**dummy_inputs)
+

 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
@@ -280,8 +292,9 @@ class IPEXModelForQuestionAnswering(IPEXModel):
     auto_model_class = AutoModelForQuestionAnswering
     export_feature = "question-answering"

+    @wraps(IPEXModel.forward)
     def forward(self, *args, **kwargs):
-        outputs = self._call_model(*args, **kwargs)
+        outputs = super().forward(*args, **kwargs)
         start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
         end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
         return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -299,7 +312,8 @@ def __init__(
         use_cache: bool = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir)
+        # Perform the initial warmup at the end of __init__
+        super().__init__(model, config, model_save_dir=model_save_dir, initial_warmup=False)
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

         self.model_dtype = kwargs.get("model_dtype", self.dtype)
@@ -315,6 +329,7 @@ def __init__(
         config.is_decoder = True
         config.is_encoder_decoder = False
         self.generation_config = GenerationConfig.from_model_config(config)
+        self._init_warmup()

     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
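Patch 3 warms models up at construction time because the first forward passes of a jit-traced IPEX model still perform one-time graph optimization, so their latency (and any benchmark taken from them) is misleading. A rough sketch of how the effect can be observed, assuming an optimum-intel install; the checkpoint name is only an example:

    import time
    import torch
    from transformers import AutoTokenizer
    from optimum.intel import IPEXModelForSequenceClassification

    model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True)

    inputs = tokenizer("warmup example", return_tensors="pt")
    for i in range(3):
        start = time.perf_counter()
        with torch.no_grad():
            model(**inputs)
        # With the automatic warmup the first timed call is already steady;
        # without it, the first one or two calls would dominate.
        print(f"call {i}: {time.perf_counter() - start:.3f}s")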
From 10127707a958dcdb72528c0b0276679272539587 Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Wed, 31 Jan 2024 01:39:16 -0800
Subject: [PATCH 4/6] Minor fix

---
 optimum/intel/ipex/modeling_base.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index b61e984d0a..05adad3452 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -63,7 +63,7 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        initial_warmup: bool = True,
+        warmup: bool = True,
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
@@ -81,7 +81,7 @@ def __init__(
         AutoConfig.register(self.base_model_prefix, AutoConfig)
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
-        if initial_warmup:
+        if warmup:
             self._init_warmup()

     @classmethod
@@ -310,10 +310,11 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: bool = True,
         **kwargs,
     ):
         # Perform the initial warmup at the end of __init__
-        super().__init__(model, config, model_save_dir=model_save_dir, initial_warmup=False)
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

         self.model_dtype = kwargs.get("model_dtype", self.dtype)
@@ -329,7 +330,8 @@ def __init__(
         config.is_decoder = True
         config.is_encoder_decoder = False
         self.generation_config = GenerationConfig.from_model_config(config)
-        self._init_warmup()
+        if warmup:
+            self._init_warmup()

     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")

From d797cc9bbfa7939fffefafc565b1129ff779c687 Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Wed, 31 Jan 2024 08:04:34 -0800
Subject: [PATCH 5/6] Fix _init_warmup use_cache condition

---
 optimum/intel/ipex/modeling_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index a596f6d3b4..f15d610c8b 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -227,7 +227,7 @@ def _call_model(self, *args, **kwargs):
     def _init_warmup(self):
         # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
         # the results of the compute are unpredictable
-        use_cache = getattr(self, "use_cache", getattr(self.config, "use_cache", False))
+        use_cache = "past_key_values" in self.input_names
         dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
         for _ in range(2):
             self(**dummy_inputs)
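Patch 5 stops guessing `use_cache` from model and config attributes and reads it off the traced graph instead: a jit-traced module only accepts the inputs that were baked in at trace time, so `self.input_names` is the authoritative signal for whether the warmup's dummy inputs must include a KV cache. A toy illustration; `traced_input_names` and the config flag are hypothetical stand-ins:

    # A model traced without a KV cache, even though its config claims cache support
    traced_input_names = ["input_ids", "attention_mask"]
    config_use_cache = True

    # The config flag would wrongly add past_key_values to the dummy inputs;
    # the traced input names give the answer the graph will actually accept.
    use_cache = "past_key_values" in traced_input_names
    print(use_cache)  # False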
From abb7b0001974195e54b1fa81e0d7df3b28b2333e Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Wed, 31 Jan 2024 08:30:10 -0800
Subject: [PATCH 6/6] Fix output handling in IPEX question answering

---
 optimum/intel/ipex/modeling_base.py | 18 +++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index f15d610c8b..2f7267c984 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -290,9 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel):
     auto_model_class = AutoModelForQuestionAnswering
     export_feature = "question-answering"

-    @wraps(IPEXModel.forward)
-    def forward(self, *args, **kwargs):
-        outputs = super().forward(*args, **kwargs)
+    def forward(self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        if "token_type_ids" in self.input_names:
+            inputs["token_type_ids"] = token_type_ids
+
+        outputs = self._call_model(**inputs)
         start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
         end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
         return ModelOutput(start_logits=start_logits, end_logits=end_logits)
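After patch 6, the question-answering forward mirrors the other task classes: it takes explicit keyword arguments, filters them against `self.input_names`, and routes through `_call_model` so autocast is applied. A usage sketch, assuming an optimum-intel install; the checkpoint name is only an example:

    import torch
    from transformers import AutoTokenizer
    from optimum.intel import IPEXModelForQuestionAnswering

    model_id = "distilbert-base-cased-distilled-squad"  # example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)

    question = "What does IPEX extend?"
    context = "IPEX, the Intel Extension for PyTorch, extends PyTorch with optimizations for Intel hardware."
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Decode the highest-scoring answer span from the start/end logits
    start = int(outputs.start_logits.argmax())
    end = int(outputs.end_logits.argmax()) + 1
    print(tokenizer.decode(inputs["input_ids"][0][start:end]))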