Skip to content

Commit 6a83c2e

Browse files
committed
Warmup IPEX models at init
1 parent f468914 commit 6a83c2e

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

optimum/intel/ipex/modeling_base.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717
import os
18+
from functools import wraps
1819
from pathlib import Path
1920
from tempfile import TemporaryDirectory
2021
from typing import Optional, Tuple, Union
@@ -43,7 +44,7 @@
4344
from optimum.modeling_base import OptimizedModel
4445
from optimum.utils import NormalizedConfigManager
4546

46-
from ..generation.modeling import jit_trace
47+
from ..generation.modeling import jit_trace, prepare_jit_inputs
4748
from ..utils.import_utils import is_torch_version
4849
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
4950

@@ -62,6 +63,7 @@ def __init__(
6263
model,
6364
config: PretrainedConfig = None,
6465
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
66+
initial_warmup: bool = True,
6567
**kwargs,
6668
):
6769
OptimizedModel.__init__(self, model=model, config=config)
@@ -79,6 +81,8 @@ def __init__(
7981
AutoConfig.register(self.base_model_prefix, AutoConfig)
8082
if hasattr(self.auto_model_class, "register"):
8183
self.auto_model_class.register(AutoConfig, self.__class__)
84+
if initial_warmup:
85+
self._init_warmup()
8286

8387
@classmethod
8488
def _from_transformers(
@@ -222,6 +226,14 @@ def _call_model(self, *args, **kwargs):
222226
out = self.model(*args, **kwargs)
223227
return out
224228

229+
def _init_warmup(self):
230+
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
231+
# the results of the compute are unpredictable
232+
use_cache = getattr(self, "use_cache", getattr(self.config, "use_cache", False))
233+
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
234+
for _ in range(2):
235+
self(**dummy_inputs)
236+
225237

226238
class IPEXModelForSequenceClassification(IPEXModel):
227239
auto_model_class = AutoModelForSequenceClassification
@@ -280,8 +292,9 @@ class IPEXModelForQuestionAnswering(IPEXModel):
280292
auto_model_class = AutoModelForQuestionAnswering
281293
export_feature = "question-answering"
282294

295+
@wraps(IPEXModel.forward)
283296
def forward(self, *args, **kwargs):
284-
outputs = self._call_model(*args, **kwargs)
297+
outputs = super().forward(*args, **kwargs)
285298
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
286299
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
287300
return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -299,7 +312,8 @@ def __init__(
299312
use_cache: bool = True,
300313
**kwargs,
301314
):
302-
super().__init__(model, config, model_save_dir=model_save_dir)
315+
# Perform the initial warmup at the end of __init__
316+
super().__init__(model, config, model_save_dir=model_save_dir, initial_warmup=False)
303317

304318
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
305319
self.model_dtype = kwargs.get("model_dtype", self.dtype)
@@ -315,6 +329,7 @@ def __init__(
315329
config.is_decoder = True
316330
config.is_encoder_decoder = False
317331
self.generation_config = GenerationConfig.from_model_config(config)
332+
self._init_warmup()
318333

319334
def _prepare_past_key_values(self, input_ids):
320335
model_type = self.config.model_type.replace("_", "-")

0 commit comments

Comments
 (0)