@@ -149,7 +149,7 @@ def __init__(
149
149
150
150
self .maybe_apply_torch_compile ()
151
151
152
- if warmup :
152
+ if warmup and not self . compiled :
153
153
self ._init_warmup ()
154
154
155
155
@classmethod
@@ -240,14 +240,11 @@ def maybe_apply_torch_compile(self):
240
240
self .compiled = True
241
241
242
242
def _init_warmup (self ):
243
- if self .compiled :
244
- logger .info ("Detected torch.compile is applied, please warm-up by your own case" )
245
- else :
246
- inputs = prepare_jit_inputs (self .model , self .export_feature , False )
247
- with torch .no_grad ():
248
- self .model (** inputs )
249
- self .model (** inputs )
250
- logger .info ("Warm up end" )
243
+ inputs = prepare_jit_inputs (self .model , self .export_feature , False )
244
+ with torch .no_grad ():
245
+ self .model (** inputs )
246
+ self .model (** inputs )
247
+ logger .info ("Warm up end" )
251
248
252
249
253
250
class IPEXModelForSequenceClassification (IPEXModel ):
@@ -320,7 +317,7 @@ def __init__(
320
317
if hasattr (self .model_cls , "_convert_to_bloom_cache" ):
321
318
self ._convert_to_bloom_cache = self .model_cls ._convert_to_bloom_cache
322
319
323
- if warmup :
320
+ if warmup and not self . compiled :
324
321
self ._init_warmup ()
325
322
326
323
@torch .no_grad ()
@@ -403,13 +400,10 @@ def generate(self, *args, **kwargs):
403
400
return result
404
401
405
402
def _init_warmup (self ):
406
- if self .compiled :
407
- logger .info ("Detected torch.compile is applied, please warm-up by your own case" )
408
- else :
409
- inputs = prepare_jit_inputs (self .model , self .export_feature , False )
410
- self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
411
- self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
412
- logger .info ("Warm up end" )
403
+ inputs = prepare_jit_inputs (self .model , self .export_feature , False )
404
+ self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
405
+ self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
406
+ logger .info ("Warm up end" )
413
407
414
408
415
409
class IPEXModelForSeq2SeqLM (IPEXModel , GenerationMixin ):
@@ -445,7 +439,7 @@ def __init__(
445
439
if hasattr (self .model_cls , "_convert_to_standard_cache" ):
446
440
self ._convert_to_standard_cache = self .model_cls ._convert_to_standard_cache
447
441
448
- if warmup :
442
+ if warmup and not self . compiled :
449
443
self ._init_warmup ()
450
444
451
445
@torch .no_grad ()
@@ -484,13 +478,10 @@ def _supports_num_logits_to_keep(self) -> bool:
484
478
return "num_logits_to_keep" in set (inspect .signature (self .model .forward ).parameters .keys ())
485
479
486
480
def _init_warmup (self ):
487
- if self .compiled :
488
- logger .info ("Detected torch.compile is applied, please warm-up by your own case" )
489
- else :
490
- inputs = prepare_jit_inputs (self .model , self .export_feature , False )
491
- self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
492
- self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
493
- logger .info ("Warm up end" )
481
+ inputs = prepare_jit_inputs (self .model , self .export_feature , False )
482
+ self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
483
+ self .generate (input_ids = inputs ["input_ids" ], attention_mask = inputs ["attention_mask" ], max_new_tokens = 4 )
484
+ logger .info ("Warm up end" )
494
485
495
486
496
487
def _ipex_crop_past_key_values (model , past_key_values , max_length ):
0 commit comments