Refactor IPEX CausalLM for better model architecture scale #544

Merged 2 commits on Jan 31, 2024
13 changes: 1 addition & 12 deletions optimum/intel/ipex/inference.py
@@ -31,12 +31,7 @@
     IPEXModelForMaskedLM,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
-    IPEXBloomForCausalLM,
-    IPEXMPTForCausalLM,
-    IPEXOPTForCausalLM,
-    IPEXGPTBigCodeForCausalLM,
     IPEXModelForQuestionAnswering,
-    _MODEL_TYPE_TO_AUTOMODELS,
 )


@@ -139,13 +134,7 @@ def __enter__(self):
                 )
                 if task in _HEAD_TO_AUTOMODELS:
                     model = jit_trace(model, task, use_cache)
-                    model_type = getattr(self._original.config, "model_type", "").replace("_", "-")
-
-                    if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys():
-                        auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[model_type]
-                    else:
-                        auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
-
+                    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
                     model = auto_model_class(model, self._original.config, use_cache=use_cache)
 
                 # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
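With the per-architecture wrapper classes gone, the wrapper is chosen from the task alone. A minimal sketch of that dispatch; the class stubs and mapping entries below are illustrative placeholders, not the real optimum-intel objects:

# Placeholder classes standing in for the real IPEX wrappers.
class IPEXModelForCausalLM: ...
class IPEXModelForSequenceClassification: ...

# Illustrative task-to-class-name mapping; the real _HEAD_TO_AUTOMODELS may differ.
_HEAD_TO_AUTOMODELS = {
    "text-generation": "IPEXModelForCausalLM",
    "text-classification": "IPEXModelForSequenceClassification",
}

task = "text-generation"
auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])  # resolve the class by name
assert auto_model_class is IPEXModelForCausalLM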
244 changes: 15 additions & 229 deletions optimum/intel/ipex/modeling_base.py
@@ -36,7 +36,9 @@
     GenerationMixin,
     PretrainedConfig,
 )
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
+from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 from transformers.utils import WEIGHTS_NAME
 
 from optimum.exporters import TasksManager
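The two new imports are what lets a single IPEXModelForCausalLM serve every architecture: they resolve the original Transformers modeling class, either from config.auto_map for remote-code models or from the regular auto-class mapping. A rough sketch of that resolution, where "gpt2" is only a placeholder checkpoint:

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.models.auto.auto_factory import _get_model_class as get_model_class

config = AutoConfig.from_pretrained("gpt2")  # placeholder checkpoint

if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
    # Remote-code models name their modeling class in `config.auto_map`.
    model_cls = get_class_from_dynamic_module(config.auto_map["AutoModelForCausalLM"], "gpt2")
else:
    # Library models resolve through the standard auto-class mapping.
    model_cls = get_model_class(config, AutoModelForCausalLM._model_mapping)

print(model_cls.__name__)  # GPT2LMHeadModel for this placeholder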
@@ -164,12 +166,8 @@ def _from_pretrained(
 
         model = torch.jit.load(model_cache_path)
         torch.jit.freeze(model.eval())
-        model_type = config.model_type.replace("_", "-")
-        init_cls = cls
-        if cls.export_feature == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS:
-            init_cls = _MODEL_TYPE_TO_AUTOMODELS[model_type]
 
-        return init_cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)
@@ -302,6 +300,18 @@ def __init__(
         config.is_decoder = True
         config.is_encoder_decoder = False
         self.generation_config = GenerationConfig.from_model_config(config)
+        try:
+            self.model_cls = get_class_from_dynamic_module(
+                self.config.auto_map["AutoModelForCausalLM"], model_save_dir
+            )
+        except AttributeError:
+            self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
+        self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
+        self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
+        if hasattr(self.model_cls, "_convert_to_standard_cache"):
+            self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
+        if hasattr(self.model_cls, "_convert_to_bloom_cache"):
+            self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
 
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
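The added block above is the core of the refactor: instead of one IPEX subclass per architecture, the wrapper resolves the original modeling class at init time and borrows its generation helpers, binding the unbound functions to the IPEX instance through the descriptor protocol (__get__). A minimal, self-contained sketch of that binding pattern, with purely illustrative class names:

class Donor:
    def greet(self, name):
        return f"{self.prefix} {name}"

class Wrapper:
    def __init__(self):
        self.prefix = "hello"
        # Bind Donor.greet to *this* instance via the descriptor protocol,
        # the same pattern used for prepare_inputs_for_generation above.
        self.greet = Donor.greet.__get__(self)

w = Wrapper()
print(w.greet("world"))  # -> "hello world"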
@@ -378,227 +388,3 @@ def forward(
        past_key_values = outputs["past_key_values"] if self.use_cache else None

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

Everything below this point is removed by this hunk: the per-architecture CausalLM subclasses and the _MODEL_TYPE_TO_AUTOMODELS dispatch table built from them.

    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)
        position_ids = kwargs.get("position_ids", None)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }
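For illustration only: the prefix trimming above keeps just the tokens that the cached key/values do not already cover. A toy run-through with made-up shapes:

import torch

input_ids = torch.arange(7).unsqueeze(0)   # shape [1, 7], a prompt of 7 token IDs
past_length = 5                            # i.e. past_key_values[0][0].shape[2]

remove_prefix_length = past_length if input_ids.shape[1] > past_length else input_ids.shape[1] - 1
print(input_ids[:, remove_prefix_length:]) # tensor([[5, 6]]) -- only the uncached tail is fed to the model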

    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
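As a quick aside, _reorder_cache only has to permute the batch dimension of every cached tensor so it follows the surviving beams; a toy check with made-up sizes:

import torch

layer = (torch.randn(3, 2, 1, 4), torch.randn(3, 2, 1, 4))  # key/value: [batch=3, heads=2, seq=1, dim=4]
past_key_values = (layer,)                                   # a single-layer toy cache
beam_idx = torch.tensor([2, 2, 0])                           # beams 0 and 1 are replaced by copies of beam 2

reordered = tuple(
    tuple(state.index_select(0, beam_idx) for state in layer_past)
    for layer_past in past_key_values
)
assert torch.equal(reordered[0][0][0], past_key_values[0][0][2])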


class IPEXGPTBigCodeForCausalLM(IPEXModelForCausalLM):
    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        # Omit tokens covered by past_key_values
        if past_key_values:
            if self.config.multi_query:
                past_length = past_key_values[0].shape[1]
            else:
                past_length = past_key_values[0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            position_ids = None

        model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
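The on-the-fly position_ids above are simply a cumulative count of non-padding tokens, with padded slots clamped to 1 (they are masked out anyway). A toy illustration:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # full-length sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])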


class IPEXBloomForCausalLM(IPEXModelForCausalLM):
    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        # only last token for input_ids if past is not None
        if past_key_values:
            # the cache may be in the standard format (e.g. in contrastive search), convert to Bloom's format if needed
            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                past_key_values = self._convert_to_bloom_cache(past_key_values)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": None,
            "attention_mask": attention_mask,
        }

    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        standardized_past = IPEXModelForCausalLM._convert_to_standard_cache(past, batch_size=len(beam_idx))

        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in standardized_past
        )
        return IPEXModelForCausalLM._convert_to_bloom_cache(reordered_past)

    @staticmethod
    def _convert_to_standard_cache(
        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int
    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
        """
        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
        num_heads, ...]))
        """
        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
        num_heads = batch_size_times_num_heads // batch_size
        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

    @staticmethod
    def _convert_to_bloom_cache(
        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]]
    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
        """
        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
        """
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )
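The two converters above are pure reshapes between the [batch, heads, ...] layout and Bloom's fused [batch * heads, ...] layout; with toy sizes the round trip is easy to verify:

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5
key = torch.randn(batch_size, num_heads, head_dim, seq_length)
value = torch.randn(batch_size, num_heads, seq_length, head_dim)

# standard -> bloom layout
bloom_key = key.view(batch_size * num_heads, head_dim, seq_length)
bloom_value = value.view(batch_size * num_heads, seq_length, head_dim)

# bloom -> standard round-trips back to the original tensors
assert torch.equal(bloom_key.view(batch_size, num_heads, head_dim, seq_length), key)
assert torch.equal(bloom_value.view(batch_size, num_heads, seq_length, head_dim), value)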


class IPEXOPTForCausalLM(IPEXModelForCausalLM):
    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": None,
            "attention_mask": attention_mask,
        }


class IPEXMPTForCausalLM(IPEXModelForCausalLM):
    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": None,
            "attention_mask": attention_mask,
        }


_MODEL_TYPE_TO_AUTOMODELS = {
    "bloom": IPEXBloomForCausalLM,
    "mpt": IPEXMPTForCausalLM,
    "opt": IPEXOPTForCausalLM,
    "gpt-bigcode": IPEXGPTBigCodeForCausalLM,
}