15
15
16
16
import logging
17
17
import os
18
+ from functools import wraps
18
19
from pathlib import Path
19
20
from tempfile import TemporaryDirectory
20
21
from typing import Optional , Tuple , Union
43
44
from optimum .modeling_base import OptimizedModel
44
45
from optimum .utils import NormalizedConfigManager
45
46
46
- from ..generation .modeling import jit_trace
47
+ from ..generation .modeling import jit_trace , prepare_jit_inputs
47
48
from ..utils .import_utils import is_torch_version
48
49
from ..utils .modeling_utils import MULTI_QUERY_ATTN_MODELS , patch_decoder_attention_mask
49
50
@@ -62,6 +63,7 @@ def __init__(
62
63
model ,
63
64
config : PretrainedConfig = None ,
64
65
model_save_dir : Optional [Union [str , Path , TemporaryDirectory ]] = None ,
66
+ initial_warmup : bool = True ,
65
67
** kwargs ,
66
68
):
67
69
OptimizedModel .__init__ (self , model = model , config = config )
@@ -79,6 +81,8 @@ def __init__(
79
81
AutoConfig .register (self .base_model_prefix , AutoConfig )
80
82
if hasattr (self .auto_model_class , "register" ):
81
83
self .auto_model_class .register (AutoConfig , self .__class__ )
84
+ if initial_warmup :
85
+ self ._init_warmup ()
82
86
83
87
@classmethod
84
88
def _from_transformers (
@@ -222,6 +226,14 @@ def _call_model(self, *args, **kwargs):
222
226
out = self .model (* args , ** kwargs )
223
227
return out
224
228
229
+ def _init_warmup (self ):
230
+ # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
231
+ # the results of the compute are unpredictable
232
+ use_cache = getattr (self , "use_cache" , getattr (self .config , "use_cache" , False ))
233
+ dummy_inputs = prepare_jit_inputs (self , self .export_feature , use_cache )
234
+ for _ in range (2 ):
235
+ self (** dummy_inputs )
236
+
225
237
226
238
class IPEXModelForSequenceClassification (IPEXModel ):
227
239
auto_model_class = AutoModelForSequenceClassification
@@ -280,8 +292,9 @@ class IPEXModelForQuestionAnswering(IPEXModel):
280
292
auto_model_class = AutoModelForQuestionAnswering
281
293
export_feature = "question-answering"
282
294
295
+ @wraps (IPEXModel .forward )
283
296
def forward (self , * args , ** kwargs ):
284
- outputs = self . _call_model (* args , ** kwargs )
297
+ outputs = super (). forward (* args , ** kwargs )
285
298
start_logits = outputs ["start_logits" ] if isinstance (outputs , dict ) else outputs [0 ]
286
299
end_logits = outputs ["end_logits" ] if isinstance (outputs , dict ) else outputs [1 ]
287
300
return ModelOutput (start_logits = start_logits , end_logits = end_logits )
@@ -299,7 +312,8 @@ def __init__(
299
312
use_cache : bool = True ,
300
313
** kwargs ,
301
314
):
302
- super ().__init__ (model , config , model_save_dir = model_save_dir )
315
+ # Perform the initial warmup at the end of __init__
316
+ super ().__init__ (model , config , model_save_dir = model_save_dir , initial_warmup = False )
303
317
304
318
self .normalized_config = NormalizedConfigManager .get_normalized_config_class (config .model_type )(config )
305
319
self .model_dtype = kwargs .get ("model_dtype" , self .dtype )
@@ -315,6 +329,7 @@ def __init__(
315
329
config .is_decoder = True
316
330
config .is_encoder_decoder = False
317
331
self .generation_config = GenerationConfig .from_model_config (config )
332
+ self ._init_warmup ()
318
333
319
334
def _prepare_past_key_values (self , input_ids ):
320
335
model_type = self .config .model_type .replace ("_" , "-" )
0 commit comments