
Commit dfcca7d

fix warmup
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 40de842 commit dfcca7d

1 file changed: +16 −25 lines changed

optimum/intel/ipex/modeling_base.py

@@ -149,7 +149,7 @@ def __init__(
 
         self.maybe_apply_torch_compile()
 
-        if warmup:
+        if warmup and not self.compiled:
             self._init_warmup()
 
     @classmethod
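
The hunk above (repeated for each model class below) moves the compile check out of `_init_warmup` and into the constructor. A condensed, illustrative sketch of the resulting control flow — the class here is a stand-in with stubbed methods, not the real implementation:

class WarmupGateSketch:
    """Stand-in showing the constructor gate introduced by this commit."""

    def __init__(self, warmup: bool = True):
        self.compiled = False
        self.maybe_apply_torch_compile()    # may set self.compiled = True
        if warmup and not self.compiled:    # compiled models now skip this
            self._init_warmup()

    def maybe_apply_torch_compile(self):
        ...  # decide whether to wrap the model in torch.compile

    def _init_warmup(self):
        ...  # dummy forward/generate passes (see the hunks below)
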
@@ -240,14 +240,11 @@ def maybe_apply_torch_compile(self):
         self.compiled = True
 
     def _init_warmup(self):
-        if self.compiled:
-            logger.info("Detected torch.compile is applied, please warm-up by your own case")
-        else:
-            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-            with torch.no_grad():
-                self.model(**inputs)
-                self.model(**inputs)
-            logger.info("Warm up end")
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        with torch.no_grad():
+            self.model(**inputs)
+            self.model(**inputs)
+        logger.info("Warm up end")
 
 
 class IPEXModelForSequenceClassification(IPEXModel):
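
With the compiled branch removed, `_init_warmup` no longer logs the "please warm-up by your own case" hint; compiled models simply never reach it. A caller-side warmup for a compiled model could look like the following minimal sketch — the model id and input text are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSequenceClassification

# Hypothetical example checkpoint; any sequence-classification model works.
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSequenceClassification.from_pretrained(model_id)

# If torch.compile was applied, the built-in warmup is now skipped, so run
# two forward passes on representative shapes, mirroring _init_warmup.
inputs = tokenizer("a short warmup sentence", return_tensors="pt")
with torch.no_grad():
    model(**inputs)  # first call pays the compilation/capture cost
    model(**inputs)  # second call exercises the optimized path
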
@@ -320,7 +317,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
 
-        if warmup:
+        if warmup and not self.compiled:
             self._init_warmup()
 
     @torch.no_grad()
@@ -403,13 +400,10 @@ def generate(self, *args, **kwargs):
         return result
 
     def _init_warmup(self):
-        if self.compiled:
-            logger.info("Detected torch.compile is applied, please warm-up by your own case")
-        else:
-            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-            logger.info("Warm up end")
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
 
 
 class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin):
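
For the generation classes, the in-tree warmup is two short `generate` calls rather than raw forward passes. The equivalent caller-side warmup for a compiled model, as a sketch with a hypothetical model id and prompt:

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

# Hypothetical example checkpoint.
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)

# Mirror the in-tree warmup: two short generate() calls so the
# compilation-heavy first run happens before real requests arrive.
inputs = tokenizer("hello", return_tensors="pt")
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=4)
    model.generate(**inputs, max_new_tokens=4)
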
@@ -445,7 +439,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
 
-        if warmup:
+        if warmup and not self.compiled:
             self._init_warmup()
 
     @torch.no_grad()
@@ -484,13 +478,10 @@ def _supports_num_logits_to_keep(self) -> bool:
         return "num_logits_to_keep" in set(inspect.signature(self.model.forward).parameters.keys())
 
     def _init_warmup(self):
-        if self.compiled:
-            logger.info("Detected torch.compile is applied, please warm-up by your own case")
-        else:
-            inputs = prepare_jit_inputs(self.model, self.export_feature, False)
-            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-            self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
-            logger.info("Warm up end")
+        inputs = prepare_jit_inputs(self.model, self.export_feature, False)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4)
+        logger.info("Warm up end")
 
 
 def _ipex_crop_past_key_values(model, past_key_values, max_length):
