Add an initial warmup step to IPEXModels #543

Merged: 7 commits, Jan 31, 2024
Changes from 4 commits
46 changes: 38 additions & 8 deletions optimum/intel/ipex/modeling_base.py
@@ -15,6 +15,7 @@

import logging
import os
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union
@@ -43,7 +44,7 @@
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ..generation.modeling import jit_trace
from ..generation.modeling import jit_trace, prepare_jit_inputs
from ..utils.import_utils import is_torch_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask

@@ -62,11 +63,13 @@ def __init__(
model,
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
warmup: bool = True,
**kwargs,
):
OptimizedModel.__init__(self, model=model, config=config)
# To do: add XPU support
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
self.model.to(self._device)
self.model_save_dir = model_save_dir

@@ -78,6 +81,8 @@ def __init__(
AutoConfig.register(self.base_model_prefix, AutoConfig)
if hasattr(self.auto_model_class, "register"):
self.auto_model_class.register(AutoConfig, self.__class__)
if warmup:
self._init_warmup()

@classmethod
def _from_transformers(
@@ -190,7 +195,7 @@ def forward(
if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

outputs = self.model(**inputs)
outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

def eval(self):
@@ -201,6 +206,10 @@ def eval(self):
def device(self) -> torch.device:
return self._device

@property
def dtype(self) -> torch.dtype:
return self._dtype

def to(self, device: Union[torch.device, str]):
self._device = device if isinstance(device, torch.device) else torch.device(device)
self.model.to(self._device)
@@ -209,6 +218,22 @@ def to(self, device: Union[torch.device, str]):
def can_generate(self):
return isinstance(self, GenerationMixin)

def _call_model(self, *args, **kwargs):
try:
with torch.autocast(self.device.type, self.dtype):
out = self.model(*args, **kwargs)
except RuntimeError:
out = self.model(*args, **kwargs)
return out

def _init_warmup(self):
# Warmup: the first 2 forwards of an IPEX model include some preprocessing steps,
# so the performance of those calls is unpredictable
use_cache = getattr(self, "use_cache", getattr(self.config, "use_cache", False))
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
for _ in range(2):
self(**dummy_inputs)


class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
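
The new _call_model and _init_warmup helpers above implement a simple pattern: route every forward pass through torch.autocast with an eager fallback, then prime the model with two dummy passes at construction time. A standalone sketch of the same pattern (hypothetical helper names, not part of this PR):

import torch

def call_with_autocast(model, device: torch.device, dtype: torch.dtype, **inputs):
    # Mirrors _call_model: try mixed precision first, then fall back to an
    # eager forward if autocast rejects this device/dtype combination.
    try:
        with torch.autocast(device.type, dtype):
            return model(**inputs)
    except RuntimeError:
        return model(**inputs)

def run_warmup(model, dummy_inputs, passes=2):
    # Mirrors _init_warmup: the first couple of IPEX forwards trigger one-off
    # graph preparation, so run them on dummy data before serving real inputs.
    for _ in range(passes):
        call_with_autocast(model, torch.device("cpu"), torch.float32, **dummy_inputs)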
@@ -238,7 +263,7 @@ def forward(
"pixel_values": pixel_values,
}

outputs = self.model(**inputs)
outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


@@ -259,16 +284,17 @@ def forward(
if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

outputs = self.model(**inputs)
outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForQuestionAnswering(IPEXModel):
auto_model_class = AutoModelForQuestionAnswering
export_feature = "question-answering"

@wraps(IPEXModel.forward)
def forward(self, *args, **kwargs):
outputs = self.model(*args, **kwargs)
outputs = super().forward(*args, **kwargs)
Collaborator: why is this needed?

Contributor (author): prepare_jit_inputs looks at the signature of the forward function, and the wraps and super call help avoid copying code.

Collaborator: would prefer we avoid this, as it will fail in case outputs is not a dict:

return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

Collaborator: also not sure I see the link with prepare_jit_inputs.

Contributor (author):

1. In _init_warmup we call prepare_jit_inputs, which examines the passed model's forward signature to see which of the candidate dummy inputs exist in that signature. If we don't use wraps, we get the signature

       (self, *args, **kwargs)

   instead of

       (self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor = None, **kwargs)

   (illustrated in the sketch after this hunk).
2. outputs will always be a dict because this is the output of IPEXModel.forward, no?

Collaborator (@echarlaix, Jan 31, 2024):

1. OK, I understand. I was thinking that prepare_jit_inputs was only used for the torchscript export, but I see that it's also used in _init_warmup; thanks for the clarification.
2. Here I'm talking about the outputs at https://github.com/huggingface/optimum-intel/blob/8ee487dc2ade5bd0023d1bbe0a0103d6af8821e0/optimum/intel/ipex/modeling_base.py#L192C9-L192C16

start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
return ModelOutput(start_logits=start_logits, end_logits=end_logits)
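
A minimal, self-contained illustration of the functools.wraps point from the thread above (toy classes, unrelated to the actual IPEXModel code): inspect.signature follows the __wrapped__ attribute that wraps sets, which is what lets prepare_jit_inputs see the real input names.

import inspect
from functools import wraps

class Base:
    def forward(self, input_ids, attention_mask, token_type_ids=None, **kwargs):
        return {"logits": input_ids}

class WithoutWraps(Base):
    def forward(self, *args, **kwargs):
        # The signature of Base.forward is hidden behind *args/**kwargs.
        return super().forward(*args, **kwargs)

class WithWraps(Base):
    @wraps(Base.forward)
    def forward(self, *args, **kwargs):
        # wraps sets __wrapped__, so inspect.signature resolves to Base.forward.
        return super().forward(*args, **kwargs)

print(inspect.signature(WithoutWraps.forward))  # (self, *args, **kwargs)
print(inspect.signature(WithWraps.forward))
# (self, input_ids, attention_mask, token_type_ids=None, **kwargs)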
@@ -284,12 +310,14 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
warmup: bool = True,
**kwargs,
):
super().__init__(model, config, model_save_dir=model_save_dir)
# Perform the initial warmup at the end of __init__
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)
self.model_dtype = kwargs.get("model_dtype", self.dtype)
self.use_cache = "past_key_values" in self.input_names

if use_cache ^ self.use_cache:
Expand All @@ -302,6 +330,8 @@ def __init__(
config.is_decoder = True
config.is_encoder_decoder = False
self.generation_config = GenerationConfig.from_model_config(config)
if warmup:
self._init_warmup()

def _prepare_past_key_values(self, input_ids):
model_type = self.config.model_type.replace("_", "-")
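
Note the order of operations in the __init__ above: the warmup is deferred because _init_warmup builds dummy inputs from attributes such as use_cache, which are only set after the base constructor returns. Schematically (toy classes, not the actual implementation):

class Base:
    def __init__(self, warmup: bool = True):
        # ... base setup ...
        if warmup:
            self._init_warmup()

    def _init_warmup(self):
        pass  # stand-in for the two dummy forward passes

class Decoder(Base):
    def __init__(self, warmup: bool = True):
        # Skip the base-class warmup: _init_warmup needs use_cache, which is
        # only known once the rest of this __init__ has run.
        super().__init__(warmup=False)
        self.use_cache = True  # in the real code, derived from input_names
        if warmup:
            self._init_warmup()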
@@ -367,7 +397,7 @@ def forward(
inputs["past_key_values"] = past_key_values

# 2. Model forward
outputs = self.model(**inputs)
outputs = self._call_model(**inputs)

# 3. Process model outputs
if isinstance(outputs, (list, tuple)):
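
With this change the warmup runs automatically at load time. A usage sketch, assuming the usual optimum-intel loading API and that from_pretrained forwards the new warmup flag to __init__:

from transformers import AutoTokenizer

from optimum.intel import IPEXModelForCausalLM

# Loading traces the model and runs the two dummy forward passes, so the
# first real generation call sees stable latency; pass warmup=False to skip.
model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, world", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(generated[0]))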