Commit 788e458

Add an initial warmup step to IPEXModels (#543)
* Handle autocast in IPEXModel.forward
* Handle missing torch_dtype in config
* Warmup IPEX models at init
* Minor fix
* Fix _init_warmup use_cache condition
* Fix output handling in IPEX question answering
1 parent 8ee487d commit 788e458

File tree

1 file changed: +33 -4 lines changed

optimum/intel/ipex/modeling_base.py

+33 -4
@@ -15,6 +15,7 @@
 
 import logging
 import os
+from functools import wraps
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional, Tuple, Union
@@ -45,7 +46,7 @@
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
-from ..generation.modeling import jit_trace
+from ..generation.modeling import jit_trace, prepare_jit_inputs
 from ..utils.import_utils import is_torch_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
 
@@ -64,6 +65,7 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        warmup: bool = True,
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
@@ -81,6 +83,8 @@ def __init__(
         AutoConfig.register(self.base_model_prefix, AutoConfig)
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
+        if warmup:
+            self._init_warmup()
 
     @classmethod
     def _from_transformers(
@@ -220,6 +224,14 @@ def _call_model(self, *args, **kwargs):
             out = self.model(*args, **kwargs)
         return out
 
+    def _init_warmup(self):
+        # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
+        # the results of the compute are unpredictable
+        use_cache = "past_key_values" in self.input_names
+        dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+        for _ in range(2):
+            self(**dummy_inputs)
+
 
 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
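
With this change, simply constructing an IPEX model runs two throwaway forward passes on dummy inputs built by prepare_jit_inputs, so the unpredictable first calls happen before the user ever touches the model. A minimal sketch of the two construction paths the new warmup flag allows; the traced module, config source, and file path below are illustrative assumptions, not part of the commit:

    import torch
    from transformers import AutoConfig
    from optimum.intel.ipex.modeling_base import IPEXModelForSequenceClassification

    config = AutoConfig.from_pretrained("bert-base-uncased")   # illustrative config source
    traced_model = torch.jit.load("exported_ipex_model.pt")    # hypothetical traced TorchScript module

    # Default behaviour: two dummy forwards run inside __init__, so the first
    # real inference no longer pays the one-off IPEX preprocessing cost.
    model = IPEXModelForSequenceClassification(traced_model, config)

    # Opt out when the caller wants full control over the very first inputs.
    cold_model = IPEXModelForSequenceClassification(traced_model, config, warmup=False)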
@@ -278,8 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel):
     auto_model_class = AutoModelForQuestionAnswering
     export_feature = "question-answering"
 
-    def forward(self, *args, **kwargs):
-        outputs = self._call_model(*args, **kwargs)
+    def forward(self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        if "token_type_ids" in self.input_names:
+            inputs["token_type_ids"] = token_type_ids
+
+        outputs = self._call_model(**inputs)
         start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
         end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
         return ModelOutput(start_logits=start_logits, end_logits=end_logits)
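
The reworked forward only passes token_type_ids to the traced graph when it appears in self.input_names, so tokenizers that do not emit that field (DistilBERT-style models, for example) no longer break the traced signature. A hedged usage sketch; the model identifier is hypothetical and from_pretrained is assumed to resolve to an already-exported IPEX model:

    from transformers import AutoTokenizer
    from optimum.intel.ipex.modeling_base import IPEXModelForQuestionAnswering

    model = IPEXModelForQuestionAnswering.from_pretrained("my-exported-ipex-qa-model")  # hypothetical id
    tokenizer = AutoTokenizer.from_pretrained("my-exported-ipex-qa-model")

    inputs = tokenizer("Who wrote it?", "It was written by Ada.", return_tensors="pt")

    # token_type_ids is forwarded only if the traced model declared it as an input;
    # otherwise it is dropped instead of failing at call time.
    outputs = model(**inputs)
    start_logits, end_logits = outputs.start_logits, outputs.end_logits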
@@ -295,9 +320,11 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: bool = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir)
+        # Perform the initial warmup at the end of __init__
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
 
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", self.dtype)
@@ -325,6 +352,8 @@ def __init__(
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
+        if warmup:
+            self._init_warmup()
 
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
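
In the decoder subclass the warmup cannot run from the parent constructor, because _init_warmup builds its dummy inputs from input_names and export_feature, and the cache plumbing above is only wired up afterwards. The deferral pattern is: pass warmup=False to super().__init__, finish the subclass setup, then trigger the warmup yourself. A generic sketch of that pattern with illustrative class names, not code from the commit:

    class WarmedUpModel:
        def __init__(self, warmup: bool = True):
            # ... base setup that _init_warmup depends on ...
            if warmup:
                self._init_warmup()

        def _init_warmup(self):
            # Run a couple of throwaway forwards on dummy inputs.
            ...

    class DecoderModel(WarmedUpModel):
        def __init__(self, warmup: bool = True):
            # Defer the warmup: the dummy forwards need the cache handling set up below.
            super().__init__(warmup=False)
            # ... configure past_key_values conversion helpers, normalized config, etc. ...
            if warmup:
                self._init_warmup()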
