From 310a3ac85dad49cfec9c12473b3b3ec77507fda2 Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Tue, 30 Jan 2024 19:37:35 -0800
Subject: [PATCH 1/2] Refactor IPEX CausalLM for better model arch scale

---
 optimum/intel/ipex/inference.py     |  13 +-
 optimum/intel/ipex/modeling_base.py | 242 ++--------------------------
 2 files changed, 14 insertions(+), 241 deletions(-)

diff --git a/optimum/intel/ipex/inference.py b/optimum/intel/ipex/inference.py
index 25145a6997..ccf2da9d80 100644
--- a/optimum/intel/ipex/inference.py
+++ b/optimum/intel/ipex/inference.py
@@ -31,12 +31,7 @@
     IPEXModelForMaskedLM,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
-    IPEXBloomForCausalLM,
-    IPEXMPTForCausalLM,
-    IPEXOPTForCausalLM,
-    IPEXGPTBigCodeForCausalLM,
     IPEXModelForQuestionAnswering,
-    _MODEL_TYPE_TO_AUTOMODELS,
 )


@@ -139,13 +134,7 @@ def __enter__(self):
             )
             if task in _HEAD_TO_AUTOMODELS:
                 model = jit_trace(model, task, use_cache)
-                model_type = getattr(self._original.config, "model_type", "").replace("_", "-")
-
-                if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys():
-                    auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[model_type]
-                else:
-                    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
-
+                auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
                 model = auto_model_class(model, self._original.config, use_cache=use_cache)

         # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index a522fef265..5c4514cacc 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -38,6 +38,8 @@
 )
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.utils import WEIGHTS_NAME
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.models.auto.auto_factory import _get_model_class as get_model_class

 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
@@ -164,12 +166,8 @@ def _from_pretrained(
         model = torch.jit.load(model_cache_path)
         torch.jit.freeze(model.eval())
-        model_type = config.model_type.replace("_", "-")
-        init_cls = cls
-        if cls.export_feature == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS:
-            init_cls = _MODEL_TYPE_TO_AUTOMODELS[model_type]

-        return init_cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

     def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)
@@ -302,6 +300,16 @@ def __init__(
         config.is_decoder = True
         config.is_encoder_decoder = False
         self.generation_config = GenerationConfig.from_model_config(config)
+        try:
+            self.model_cls = get_class_from_dynamic_module(self.config.auto_map['AutoModelForCausalLM'], model_save_dir)
+        except AttributeError:
+            self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
+        self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
+        self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
+        if hasattr(self.model_cls, '_convert_to_standard_cache'):
+            self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
+        if hasattr(self.model_cls, '_convert_to_bloom_cache'):
+            self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
@@ -378,227 +386,3 @@ def forward(
         past_key_values = outputs["past_key_values"] if self.use_cache else None

         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
-
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
-
-
-class IPEXGPTBigCodeForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        # Omit tokens covered by past_key_values
-        if past_key_values:
-            if self.config.multi_query:
-                past_length = past_key_values[0].shape[1]
-            else:
-                past_length = past_key_values[0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-        else:
-            position_ids = None
-
-        model_inputs = {"input_ids": input_ids}
-        model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
-
-
-class IPEXBloomForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        # only last token for input_ids if past is not None
-        if past_key_values:
-            # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
-            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
-                past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        standardized_past = IPEXModelForCausalLM._convert_to_standard_cache(past, batch_size=len(beam_idx))
-
-        # Get a copy of `beam_idx` on all the devices where we need those indices.
-        device_to_beam_idx = {
-            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
-        }
-        reordered_past = tuple(
-            (
-                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
-                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
-            )
-            for layer_past in standardized_past
-        )
-        return IPEXModelForCausalLM._convert_to_bloom_cache(reordered_past)
-
-    @staticmethod
-    def _convert_to_standard_cache(
-        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int
-    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
-        num_heads, ...]))
-        """
-        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        num_heads = batch_size_times_num_heads // batch_size
-        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
-        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    @staticmethod
-    def _convert_to_bloom_cache(
-        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]]
-    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
-        """
-        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
-        """
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
-        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-
-class IPEXOPTForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class IPEXMPTForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-_MODEL_TYPE_TO_AUTOMODELS = {
-    "bloom": IPEXBloomForCausalLM,
-    "mpt": IPEXMPTForCausalLM,
-    "opt": IPEXOPTForCausalLM,
-    "gpt-bigcode": IPEXGPTBigCodeForCausalLM,
-}

From 461e2c9011a077b04513962a8f38c14bd4a068a5 Mon Sep 17 00:00:00 2001
From: Ofir Zafrir
Date: Wed, 31 Jan 2024 01:32:32 -0800
Subject: [PATCH 2/2] Fix style

---
 optimum/intel/ipex/modeling_base.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 5c4514cacc..b79f720348 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -36,10 +36,10 @@
     GenerationMixin,
     PretrainedConfig,
 )
-from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
-from transformers.utils import WEIGHTS_NAME
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.models.auto.auto_factory import _get_model_class as get_model_class
+from transformers.utils import WEIGHTS_NAME

 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
@@ -301,14 +301,16 @@ def __init__(
         config.is_encoder_decoder = False
         self.generation_config = GenerationConfig.from_model_config(config)
         try:
-            self.model_cls = get_class_from_dynamic_module(self.config.auto_map['AutoModelForCausalLM'], model_save_dir)
+            self.model_cls = get_class_from_dynamic_module(
+                self.config.auto_map["AutoModelForCausalLM"], model_save_dir
+            )
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
         self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
         self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
-        if hasattr(self.model_cls, '_convert_to_standard_cache'):
+        if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
-        if hasattr(self.model_cls, '_convert_to_bloom_cache'):
+        if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

     def _prepare_past_key_values(self, input_ids):
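
Note for readers following the refactor: rather than shipping one IPEX*ForCausalLM subclass per architecture, IPEXModelForCausalLM now resolves the original transformers class once in __init__ (via config.auto_map for remote-code models, otherwise via the AutoModelForCausalLM mapping) and borrows that class's generation helpers by binding them to the wrapper instance. The snippet below is a minimal, self-contained sketch of the binding trick only; DonorModel and Wrapper are toy classes invented for the illustration, not optimum-intel code.

import torch

class DonorModel:
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Whatever instance this function gets bound to supplies the state it reads.
        return {"input_ids": input_ids[:, -1:] if self.use_cache else input_ids}

class Wrapper:
    def __init__(self):
        self.use_cache = True
        # Same pattern as `self.model_cls.prepare_inputs_for_generation.__get__(self)` above:
        # the donor's plain function becomes a method bound to this Wrapper instance.
        self.prepare_inputs_for_generation = DonorModel.prepare_inputs_for_generation.__get__(self)

wrapper = Wrapper()
print(wrapper.prepare_inputs_for_generation(torch.tensor([[1, 2, 3]])))
# -> {'input_ids': tensor([[3]])}: the donor's logic running against the wrapper's state

The payoff is the diffstat above: a decoder-only architecture whose stock prepare_inputs_for_generation and _reorder_cache already behave correctly needs no dedicated IPEX subclass anymore.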
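
On the usage side nothing should change except that bloom, opt, mpt and gpt_bigcode checkpoints now go through the same generic class as every other decoder-only model. A smoke test along the following lines should keep working; this is a sketch that assumes an environment with optimum-intel and intel-extension-for-pytorch installed, and the model id is only an example.

from transformers import AutoTokenizer, pipeline
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # any decoder-only checkpoint; no dedicated IPEX subclass is needed anymore
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)  # trace/optimize with IPEX on load
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("This is an example", max_new_tokens=8)[0]["generated_text"])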