 )
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.utils import WEIGHTS_NAME
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
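The two added imports are what the `__init__` change later in this diff builds on: `get_class_from_dynamic_module` resolves a modeling class that a checkpoint registers through `config.auto_map` (remote code), while `_get_model_class` resolves it from the static `AutoModelForCausalLM` mapping. A minimal sketch of that resolution logic, mirroring the try/except added further down; the "gpt2" model id is only an example:

    # Sketch only: resolve the concrete transformers class for a config, preferring
    # remote code registered via config.auto_map, else the static auto mapping.
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.dynamic_module_utils import get_class_from_dynamic_module
    from transformers.models.auto.auto_factory import _get_model_class as get_model_class

    model_id = "gpt2"  # illustrative checkpoint
    config = AutoConfig.from_pretrained(model_id)
    try:
        model_cls = get_class_from_dynamic_module(config.auto_map["AutoModelForCausalLM"], model_id)
    except AttributeError:  # plain configs such as gpt2 have no auto_map
        model_cls = get_model_class(config, AutoModelForCausalLM._model_mapping)
    print(model_cls.__name__)  # GPT2LMHeadModel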
@@ -164,12 +166,8 @@ def _from_pretrained(
 
         model = torch.jit.load(model_cache_path)
         torch.jit.freeze(model.eval())
-        model_type = config.model_type.replace("_", "-")
-        init_cls = cls
-        if cls.export_feature == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS:
-            init_cls = _MODEL_TYPE_TO_AUTOMODELS[model_type]
 
-        return init_cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)
@@ -302,6 +300,16 @@ def __init__(
         config.is_decoder = True
         config.is_encoder_decoder = False
         self.generation_config = GenerationConfig.from_model_config(config)
+        try:
+            self.model_cls = get_class_from_dynamic_module(self.config.auto_map['AutoModelForCausalLM'], model_save_dir)
+        except AttributeError:
+            self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
+        self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
+        self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
+        if hasattr(self.model_cls, '_convert_to_standard_cache'):
+            self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
+        if hasattr(self.model_cls, '_convert_to_bloom_cache'):
+            self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
 
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
@@ -378,227 +386,3 @@ def forward(
         past_key_values = outputs["past_key_values"] if self.use_cache else None
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
-
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
-
-
-class IPEXGPTBigCodeForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        # Omit tokens covered by past_key_values
-        if past_key_values:
-            if self.config.multi_query:
-                past_length = past_key_values[0].shape[1]
-            else:
-                past_length = past_key_values[0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-        else:
-            position_ids = None
-
-        model_inputs = {"input_ids": input_ids}
-        model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
-
-
-class IPEXBloomForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        # only last token for input_ids if past is not None
-        if past_key_values:
-            # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
-            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
-                past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        standardized_past = IPEXModelForCausalLM._convert_to_standard_cache(past, batch_size=len(beam_idx))
-
-        # Get a copy of `beam_idx` on all the devices where we need those indices.
-        device_to_beam_idx = {
-            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
-        }
-        reordered_past = tuple(
-            (
-                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
-                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
-            )
-            for layer_past in standardized_past
-        )
-        return IPEXModelForCausalLM._convert_to_bloom_cache(reordered_past)
-
-    @staticmethod
-    def _convert_to_standard_cache(
-        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int
-    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
-        num_heads, ...]))
-        """
-        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        num_heads = batch_size_times_num_heads // batch_size
-        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
-        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    @staticmethod
-    def _convert_to_bloom_cache(
-        past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]]
-    ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]:
-        """
-        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
-        """
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
-        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-
-class IPEXOPTForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class IPEXMPTForCausalLM(IPEXModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            # Some generation methods already pass only the last input ID
-            if input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = input_ids.shape[1] - 1
-            input_ids = input_ids[:, remove_prefix_length:]
-
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-_MODEL_TYPE_TO_AUTOMODELS = {
-    "bloom": IPEXBloomForCausalLM,
-    "mpt": IPEXMPTForCausalLM,
-    "opt": IPEXOPTForCausalLM,
-    "gpt-bigcode": IPEXGPTBigCodeForCausalLM,
-}
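With the methods now borrowed dynamically in `__init__`, the model-specific subclasses (`IPEXGPTBigCodeForCausalLM`, `IPEXBloomForCausalLM`, `IPEXOPTForCausalLM`, `IPEXMPTForCausalLM`) and the `_MODEL_TYPE_TO_AUTOMODELS` registry they fed are removed above. From the caller's side, the visible change is the concrete type returned by `from_pretrained`; a minimal sketch, assuming optimum-intel's public API and an illustrative checkpoint id:

    # Sketch only; the model id and export flag follow optimum-intel's documented usage,
    # but treat them as assumptions rather than part of this PR.
    from optimum.intel import IPEXModelForCausalLM

    model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True)
    # Before this change, a bloom/mpt/opt/gpt-bigcode checkpoint came back as a
    # model-specific subclass; now the class you call from_pretrained on is what you get.
    assert type(model) is IPEXModelForCausalLM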