Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an initial warmup step to IPEXModels #543

Merged
merged 7 commits into from
Jan 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import os
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union
Expand Down Expand Up @@ -45,7 +46,7 @@
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ..generation.modeling import jit_trace
from ..generation.modeling import jit_trace, prepare_jit_inputs
from ..utils.import_utils import is_torch_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask

Expand All @@ -64,6 +65,7 @@ def __init__(
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
warmup: bool = True,
**kwargs,
):
OptimizedModel.__init__(self, model=model, config=config)
Expand All @@ -81,6 +83,8 @@ def __init__(
AutoConfig.register(self.base_model_prefix, AutoConfig)
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)
if warmup:
self._init_warmup()

@classmethod
def _from_transformers(
Expand Down Expand Up @@ -220,6 +224,14 @@ def _call_model(self, *args, **kwargs):
out = self.model(*args, **kwargs)
return out

def _init_warmup(self):
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
# the results of the compute are unpredictable
use_cache = "past_key_values" in self.input_names
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
for _ in range(2):
self(**dummy_inputs)


class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
Expand Down Expand Up @@ -278,8 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel):
auto_model_class = AutoModelForQuestionAnswering
export_feature = "question-answering"

def forward(self, *args, **kwargs):
outputs = self._call_model(*args, **kwargs)
def forward(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

outputs = self._call_model(**inputs)
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
return ModelOutput(start_logits=start_logits, end_logits=end_logits)
Expand All @@ -295,9 +320,11 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: bool = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir)
# Perform the initial warmup at the end of __init__
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", self.dtype)
Expand Down Expand Up @@ -325,6 +352,8 @@ def __init__(
self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
if warmup:
self._init_warmup()

def _prepare_past_key_values(self, input_ids):
model_type = self.config.model_type.replace("_", "-")
Expand Down
Loading