
Commit e3e3acf

Refactor IPEX CausalLM for better model arch scale
1 parent 3b627f4 commit e3e3acf

File tree

2 files changed: +14, -241 lines

optimum/intel/ipex/inference.py (+1, -12)

@@ -31,12 +31,7 @@
     IPEXModelForMaskedLM,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
-    IPEXBloomForCausalLM,
-    IPEXMPTForCausalLM,
-    IPEXOPTForCausalLM,
-    IPEXGPTBigCodeForCausalLM,
     IPEXModelForQuestionAnswering,
-    _MODEL_TYPE_TO_AUTOMODELS,
 )
 
 
@@ -139,13 +134,7 @@ def __enter__(self):
             )
             if task in _HEAD_TO_AUTOMODELS:
                 model = jit_trace(model, task, use_cache)
-                model_type = getattr(self._original.config, "model_type", "").replace("_", "-")
-
-                if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys():
-                    auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[model_type]
-                else:
-                    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
-
+                auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
                 model = auto_model_class(model, self._original.config, use_cache=use_cache)
 
         # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
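With the architecture-specific wrapper classes gone, __enter__ resolves the wrapper purely by task name. A minimal sketch of that lookup pattern, using a hypothetical helper and a stand-in wrapper class rather than the real _HEAD_TO_AUTOMODELS contents:

# Illustrative sketch only: a stand-in wrapper and a hypothetical mapping/helper.
# The real _HEAD_TO_AUTOMODELS maps task names to IPEX wrapper class names.
class IPEXModelForCausalLM:
    def __init__(self, model, config, use_cache=True):
        self.model, self.config, self.use_cache = model, config, use_cache

_HEAD_TO_AUTOMODELS = {"text-generation": "IPEXModelForCausalLM"}

def wrap_traced_model(model, config, task, use_cache=True):
    # One entry per task is enough: eval() resolves the class name in module
    # scope, and the wrapper adapts to the architecture at __init__ time.
    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
    return auto_model_class(model, config, use_cache=use_cache)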

optimum/intel/ipex/modeling_base.py (+13, -229)

@@ -38,6 +38,8 @@
 )
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.utils import WEIGHTS_NAME
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
@@ -164,12 +166,8 @@ def _from_pretrained(
 
         model = torch.jit.load(model_cache_path)
         torch.jit.freeze(model.eval())
-        model_type = config.model_type.replace("_", "-")
-        init_cls = cls
-        if cls.export_feature == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS:
-            init_cls = _MODEL_TYPE_TO_AUTOMODELS[model_type]
 
-        return init_cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)
@@ -302,6 +300,16 @@ def __init__(
         config.is_decoder = True
         config.is_encoder_decoder = False
         self.generation_config = GenerationConfig.from_model_config(config)
+        try:
+            self.model_cls = get_class_from_dynamic_module(self.config.auto_map['AutoModelForCausalLM'], model_save_dir)
+        except AttributeError:
+            self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
+        self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
+        self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
+        if hasattr(self.model_cls, '_convert_to_standard_cache'):
+            self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
+        if hasattr(self.model_cls, '_convert_to_bloom_cache'):
+            self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
 
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
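The added __init__ logic resolves the original Transformers class for the checkpoint (via config.auto_map for remote code, otherwise the AutoModelForCausalLM mapping) and rebinds that class's generation helpers onto the IPEX wrapper instance through the descriptor protocol. A minimal sketch of that rebinding, with hypothetical classes standing in for model_cls and the wrapper:

# Sketch of function.__get__(instance) binding, as used above for
# model_cls._reorder_cache and model_cls.prepare_inputs_for_generation.
# OriginalLM and Wrapper are hypothetical stand-ins.
class OriginalLM:
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}

class Wrapper:
    def __init__(self, model_cls):
        # __get__(self) turns the plain function into a method bound to this
        # Wrapper instance, so `self` inside it refers to the wrapper.
        self.prepare_inputs_for_generation = (
            model_cls.prepare_inputs_for_generation.__get__(self)
        )

w = Wrapper(OriginalLM)
assert w.prepare_inputs_for_generation([1, 2, 3]) == {
    "input_ids": [1, 2, 3],
    "past_key_values": None,
}

The cache-conversion helpers are copied without __get__ because they are staticmethods on the underlying model classes, so no instance binding is needed.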
@@ -378,227 +386,3 @@ def forward(
         past_key_values = outputs["past_key_values"] if self.use_cache else None
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
-
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
-
-
-class IPEXGPTBigCodeForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        # Omit tokens covered by past_key_values
-        if past_key_values:
-            if self.config.multi_query:
-                past_length = past_key_values[0].shape[1]
-            else:
-                past_length = past_key_values[0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-        else:
-            position_ids = None
-
-        model_inputs = {"input_ids": input_ids}
-        model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
-
-
-class IPEXBloomForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        # only last token for input_ids if past is not None
-        if past_key_values:
-            # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
-            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
-                past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        standardized_past = IPEXModelForCausalLM._convert_to_standard_cache(past, batch_size=len(beam_idx))
-
-        # Get a copy of `beam_idx` on all the devices where we need those indices.
-        device_to_beam_idx = {
-            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
-        }
-        reordered_past = tuple(
-            (
-                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
-                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
-            )
-            for layer_past in standardized_past
-        )
-        return IPEXModelForCausalLM._convert_to_bloom_cache(reordered_past)
-
-    @staticmethod
-    def _convert_to_standard_cache(
-        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int
-    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
-        num_heads, ...]))
-        """
-        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        num_heads = batch_size_times_num_heads // batch_size
-        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
-        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    @staticmethod
-    def _convert_to_bloom_cache(
-        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]]
-    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
-        """
-        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
-        """
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
-        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-
-class IPEXOPTForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class IPEXMPTForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-_MODEL_TYPE_TO_AUTOMODELS = {
-    "bloom": IPEXBloomForCausalLM,
-    "mpt": IPEXMPTForCausalLM,
-    "opt": IPEXOPTForCausalLM,
-    "gpt-bigcode": IPEXGPTBigCodeForCausalLM,
-}
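With the per-architecture subclasses and _MODEL_TYPE_TO_AUTOMODELS removed, a single IPEXModelForCausalLM covers bloom, mpt, opt and gpt-bigcode checkpoints, since the architecture-specific generation behavior is now picked up from the original model class at init time. A usage sketch, assuming the from_pretrained(..., export=True) entry point shown in the optimum-intel IPEX examples (the model id is illustrative):

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "bigscience/bloom-560m"  # illustrative; any supported causal LM goes through the same path
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The single wrapper class now handles bloom/mpt/opt/gpt-bigcode without a dedicated subclass
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
inputs = tokenizer("This refactor means", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))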

0 commit comments