@@ -16,5 +16,6 @@
 import logging
 import os
+from functools import wraps
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional, Tuple, Union
@@ -45,7 +46,7 @@
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
-from ..generation.modeling import jit_trace
+from ..generation.modeling import jit_trace, prepare_jit_inputs
 from ..utils.import_utils import is_torch_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
 
@@ -64,6 +65,7 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        warmup: bool = True,
         **kwargs,
     ):
         OptimizedModel.__init__(self, model=model, config=config)
@@ -81,6 +83,8 @@ def __init__(
             AutoConfig.register(self.base_model_prefix, AutoConfig)
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
+        if warmup:
+            self._init_warmup()
 
     @classmethod
     def _from_transformers(
@@ -220,6 +224,14 @@ def _call_model(self, *args, **kwargs):
             out = self.model(*args, **kwargs)
         return out
 
+    def _init_warmup(self):
+        # Warmup: the first 2 forwards of an IPEX model include some preprocessing steps,
+        # and the results of the computation are unpredictable
+        use_cache = "past_key_values" in self.input_names
+        dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+        for _ in range(2):
+            self(**dummy_inputs)
+
 
 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
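For illustration, here is what the warmup buys a caller. A minimal sketch, assuming an exported IPEX model; the checkpoint name and the surrounding harness are ours, not part of this PR:

```python
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSequenceClassification

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# With warmup=True (the default), two dummy forwards already ran in __init__,
# so the first real call below no longer pays the one-time preprocessing cost.
model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True)

inputs = tokenizer("IPEX warmup demo", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)
```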
@@ -278,8 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel):
     auto_model_class = AutoModelForQuestionAnswering
     export_feature = "question-answering"
 
-    def forward(self, *args, **kwargs):
-        outputs = self._call_model(*args, **kwargs)
+    def forward(self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        if "token_type_ids" in self.input_names:
+            inputs["token_type_ids"] = token_type_ids
+
+        outputs = self._call_model(**inputs)
         start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
         end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
         return ModelOutput(start_logits=start_logits, end_logits=end_logits)
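The explicit signature exists so that inputs the traced graph does not consume are dropped via `self.input_names` rather than forwarded as `None`. A usage sketch; the checkpoint is illustrative:

```python
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForQuestionAnswering

model_id = "distilbert-base-cased-distilled-squad"  # illustrative checkpoint
model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

question, context = "Who wrote it?", "It was written by Ada."
inputs = tokenizer(question, context, return_tensors="pt")
# DistilBERT has no token_type_ids in its traced graph, so the forward above
# filters them out instead of passing None through to the model.
outputs = model(**inputs)
print(outputs.start_logits.argmax(-1), outputs.end_logits.argmax(-1))
```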
@@ -295,9 +320,11 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: bool = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir)
+        # Perform the initial warmup at the end of __init__
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
 
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", self.dtype)
@@ -325,6 +352,8 @@ def __init__(
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
+        if warmup:
+            self._init_warmup()
 
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
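Because the decoder subclass must finish setting up its cache helpers before a forward pass is safe, it passes `warmup=False` to the base constructor and triggers `_init_warmup` itself at the end of `__init__`. A sketch of how a caller might opt out entirely, assuming the `warmup` kwarg is forwarded through `from_pretrained` (hypothetical usage, not confirmed by this diff):

```python
from optimum.intel import IPEXModelForCausalLM

# Default: two dummy forwards run during construction.
model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True)

# Hypothetical opt-out: skip the dummy forwards, e.g. when the instance is
# only being created to serialize it, at the cost of a slower first call.
cold_model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True, warmup=False)
```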