
Commit 40de842

warm up does not work for compiled model
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 6cceb30

File tree: 1 file changed, +22 −13 lines

optimum/intel/ipex/modeling_base.py

+22 −13
@@ -240,11 +240,14 @@ def maybe_apply_torch_compile(self):
         self.compiled = True
 
     def _init_warmup(self):
-        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-        with torch.no_grad():
-            self.model(**inputs)
-            self.model(**inputs)
-        logger.info("Warm up end")
+        if self.compiled:
+            logger.info("Detected torch.compile is applied, please warm-up by your own case")
+        else:
+            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+            with torch.no_grad():
+                self.model(**inputs)
+                self.model(**inputs)
+            logger.info("Warm up end")
 
 
 class IPEXModelForSequenceClassification(IPEXModel):
@@ -400,10 +403,13 @@ def generate(self, *args, **kwargs):
         return result
 
     def _init_warmup(self):
-        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        logger.info("Warm up end")
+        if self.compiled:
+            logger.info("Detected torch.compile is applied, please warm-up by your own case")
+        else:
+            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            logger.info("Warm up end")
 
 
 class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
@@ -478,10 +484,13 @@ def _supports_num_logits_to_keep(self) -> bool:
         return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
 
     def _init_warmup(self):
-        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-        logger.info("Warm up end")
+        if self.compiled:
+            logger.info("Detected torch.compile is applied, please warm-up by your own case")
+        else:
+            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+            logger.info("Warm up end")
 
 
 def _ipex_crop_past_key_values(model, past_key_values, max_length):
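
With this change, the built-in warm-up is skipped whenever torch.compile has been applied, so users are expected to warm the compiled model up with inputs that match their own workload. Below is a minimal sketch of what that could look like, assuming optimum-intel's IPEXModelForCausalLM; the checkpoint id, prompt, and generation arguments are placeholders, not part of this commit.

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# torch.compile may be applied internally (see maybe_apply_torch_compile above)
model = IPEXModelForCausalLM.from_pretrained(model_id)

# Run a couple of generations with inputs shaped like the real workload so that
# compilation happens here instead of on the first user-facing request.
warmup_inputs = tokenizer("a sample prompt", return_tensors="pt")
for _ in range(2):
    model.generate(**warmup_inputs, max_new_tokens=4)

Using workload-shaped inputs matters because a generic warm-up with dummy shapes would not necessarily cover the shapes the compiled model will actually see.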
