@@ -30,6 +30,7 @@
     AutoModelForImageClassification,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
     GenerationConfig,
@@ -60,8 +61,8 @@
 _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
-# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6
-_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit")
+# TODO: Some models are already fixed in torch 2.6; enable them once torch is upgraded to 2.6
+_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")


 def _is_patched_with_ipex(model, task, use_cache: bool = True):
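Editor's note: a minimal sketch of how this blocklist is consulted before compilation is applied further down (in maybe_apply_torch_compile); the checkpoint id and the inline copy of the tuple are illustrative assumptions, not part of the patch.

# Hypothetical illustration of the compile-readiness gate.
from transformers import AutoConfig

_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2")

config = AutoConfig.from_pretrained("gpt2")  # example checkpoint (assumption)
if config.model_type in _COMPILE_NOT_READY_MODEL_TYPES:
    print(f"torch.compile disabled for {config.model_type} until torch 2.6")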
@@ -84,15 +85,21 @@ def __init__(
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        warmup: Optional[bool] = True,
         **kwargs,
     ):
         config = config or model.config
         OptimizedModel.__init__(self, model=model, config=config)

+        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
+        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
+        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
+        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
         self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
         self.use_cache = kwargs.get("use_cache", False)
         self.model_save_dir = model_save_dir
         self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
+        self.compiled = False

         self.input_names = set(inspect.signature(model.forward).parameters)
@@ -104,25 +111,10 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)

-        # Non-generation tasks can use torch.compile to get acceleration.
-        if (
-            model.device.type == "cpu"
-            and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS
-            and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES
-            and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
-        ):
-            from torch._inductor import config
-
-            # System level optimization
-            torch._inductor.config.cpp_wrapper = True
-            os.environ["TORCHINDUCTOR_FREEZING"] = "1"
-            logger.info("Enable torch.compile optimization, start warm up")
-            self.model.forward = torch.compile(self.model.forward)
-            inputs = prepare_jit_inputs(model, self.export_feature, False)
-            with torch.no_grad():
-                self.model(**inputs)
-                self.model(**inputs)
-            logger.info("Warm up end")
+        self.maybe_apply_torch_compile()
+
+        if warmup:
+            self._init_warmup()

     @classmethod
     def _from_transformers(cls, *args, **kwargs):
@@ -192,6 +184,31 @@ def to(self, device: Union[torch.device, str]):
     def can_generate(self):
         return isinstance(self, GenerationMixin)

+    def maybe_apply_torch_compile(self):
+        if (
+            self.model.device.type != "cpu"
+            or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
+            or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
+        ):
+            return
+        if self.use_cache and not self._supports_static_cache:
+            return
+        from torch._inductor import config as inductor_config
+
+        # System level optimization
+        inductor_config.cpp_wrapper = True
+        os.environ["TORCHINDUCTOR_FREEZING"] = "1"
+        logger.info("Enable torch.compile optimization")
+        self.model.forward = torch.compile(self.model.forward)
+        self.compiled = True
+
+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        with torch.no_grad():
+            self.model(**inputs)
+            self.model(**inputs)
+        logger.info("Warm up end")
+

 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
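Editor's note: a standalone sketch of the recipe that maybe_apply_torch_compile() and _init_warmup() implement together; the toy Linear module and dummy input are assumptions for illustration, standing in for the wrapped HF model and prepare_jit_inputs output.

import os
import torch
from torch._inductor import config as inductor_config

# System-level Inductor settings, as set above: emit a C++ wrapper around the
# compiled graph and fold weights into constants at compile time.
inductor_config.cpp_wrapper = True
os.environ["TORCHINDUCTOR_FREEZING"] = "1"

model = torch.nn.Linear(8, 8).eval()  # stand-in for the HF model (assumption)
model.forward = torch.compile(model.forward)

example = torch.randn(1, 8)
with torch.no_grad():
    model(example)  # first call triggers compilation (the warm-up)
    model(example)  # second call runs the cached compiled graph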
@@ -236,16 +253,10 @@ def __init__(
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
+        warmup: Optional[bool] = True,
         **kwargs,
     ):
-        super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache)
-
-        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
-        self._supports_sdpa = getattr(model, "_supports_sdpa", None)
-        self._supports_cache_class = getattr(model, "_supports_cache_class", None)
-        self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
-        self._supports_static_cache = getattr(model, "_supports_static_cache", None)
-
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
         if self._add_patch:
             self._supports_cache_class = True
         GenerationMixin.__init__(self)
@@ -269,6 +280,9 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

+        if warmup:
+            self._init_warmup()
+
     @torch.no_grad()
     def forward(
         self,
@@ -285,6 +299,9 @@ def _prepare_generation_config(
     ) -> Tuple[GenerationConfig, Dict]:
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         generation_method = generation_config.get_generation_mode().value
+        if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
+            # Use static cache for torch.compile
+            generation_config.cache_implementation = "static"
         if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
             raise ValueError(
                 f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
@@ -337,6 +354,83 @@ def generate(self, *args, **kwargs):

         return result

+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
+
+
+class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
+    auto_model_class = AutoModelForSeq2SeqLM
+    export_feature = "text2text-generation"
+
+    def __init__(
+        self,
+        model,
+        config: PretrainedConfig = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        use_cache: bool = True,
+        warmup: Optional[bool] = True,
+        **kwargs,
+    ):
+        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache)
+        GenerationMixin.__init__(self)
+
+        model_type = self.config.model_type.replace("_", "-")
+        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config)
+
+        self.config.is_decoder = False
+        self.config.is_encoder_decoder = True
+
+        self.generation_config = GenerationConfig.from_model_config(self.config)
+        try:
+            self.model_cls = get_class_from_dynamic_module(
+                self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
+            )
+        except AttributeError:
+            self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)
+
+        if hasattr(self.model_cls, "_convert_to_standard_cache"):
+            self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
+
+        if warmup:
+            self._init_warmup()
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
+    def _prepare_generation_config(
+        self, generation_config: Optional[GenerationConfig], **kwargs: Dict
+    ) -> Tuple[GenerationConfig, Dict]:
+        generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
+        # Use static cache for torch.compile
+        if self.compiled:
+            generation_config.cache_implementation = "static"
+
+        return generation_config, model_kwargs
+
+    def _reorder_cache(self, *args, **kwargs):
+        return self.model._reorder_cache(*args, **kwargs)
+
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        return self.model.prepare_inputs_for_generation(*args, **kwargs)
+
+    def get_encoder(self, *args, **kwargs):
+        return self.model.get_encoder(*args, **kwargs)
+
+    def _init_warmup(self):
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
+

 def _ipex_crop_past_key_values(model, past_key_values, max_length):
     if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
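Editor's note: a possible end-to-end use of the new class, assuming it is exported from optimum.intel like the other IPEX model classes; the checkpoint id is illustrative.

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSeq2SeqLM  # assumed export path

model_id = "t5-small"  # example checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSeq2SeqLM.from_pretrained(model_id)  # runs the two warm-up generates by default

inputs = tokenizer("translate English to French: Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))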