@@ -240,11 +240,14 @@ def maybe_apply_torch_compile(self):
         self.compiled = True

     def _init_warmup(self):
-        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-        with torch.no_grad():
-            self.model(**inputs)
-            self.model(**inputs)
-        logger.info("Warm up end")
+        if self.compiled:
+            logger.info("Detected torch.compile is applied, please warm-up by your own case")
+        else:
+            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+            with torch.no_grad():
+                self.model(**inputs)
+                self.model(**inputs)
+            logger.info("Warm up end")


 class IPEXModelForSequenceClassification(IPEXModel):
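
With this change the built-in forward warm-up is skipped whenever `self.compiled` is set, so callers of task models such as `IPEXModelForSequenceClassification` are expected to warm the compiled graph themselves. The snippet below is a minimal, hypothetical sketch of doing that with your own inputs; the checkpoint name is a placeholder, and the exact loading options that lead to torch.compile being applied depend on your optimum-intel version.

```python
# Minimal warm-up sketch (assumption: the loaded IPEX model ends up with
# torch.compile applied, so the automatic prepare_jit_inputs warm-up is skipped).
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSequenceClassification

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSequenceClassification.from_pretrained(model_id)

# Use inputs shaped like your real workload so torch.compile specializes on the
# batch size and sequence length you will actually serve.
inputs = tokenizer("This movie was great!", return_tensors="pt")
with torch.no_grad():
    model(**inputs)  # first call may trigger (re)compilation
    model(**inputs)  # second call exercises the compiled path
```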
@@ -400,10 +403,13 @@ def generate(self, *args, **kwargs):
         return result

     def _init_warmup(self):
-        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        logger.info("Warm up end")
+        if self.compiled:
+            logger.info("Detected torch.compile is applied, please warm-up by your own case")
+        else:
+            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            logger.info("Warm up end")


 class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
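
For the generation classes the removed warm-up ran two short `generate()` calls; when torch.compile is applied, that likewise becomes the caller's responsibility. A hypothetical equivalent with user-provided inputs might look like the following (checkpoint name and prompt are placeholders):

```python
# Hypothetical generate()-based warm-up, mirroring the two short generations the
# non-compiled path performs in _init_warmup.
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
# Two short generations so both the prefill and the decode steps are exercised
# before real traffic arrives.
model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
```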
@@ -478,10 +484,13 @@ def _supports_num_logits_to_keep(self) -> bool:
         return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())

     def _init_warmup(self):
-        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        logger.info("Warm up end")
+        if self.compiled:
+            logger.info("Detected torch.compile is applied, please warm-up by your own case")
+        else:
+            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            logger.info("Warm up end")


 def _ipex_crop_past_key_values(model, past_key_values, max_length):