@@ -42,15 +42,18 @@
 from optimum.utils.normalized_config import NormalizedTextConfig
 
 from .model_patcher import (
+    AquilaModelPatcher,
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     GemmaModelPatcher,
-    InternLMPatcher,
+    InternLM2Patcher,
+    InternLMModelPatcher,
     LlamaModelPatcher,
     MixtralModelPatcher,
     MPTModelPatcher,
     Phi3ModelPatcher,
     QwenModelPatcher,
+    XverseModelPatcher,
 )
 
 
@@ -461,7 +464,7 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
-        return InternLMPatcher(self, model, model_kwargs=model_kwargs)
+        return InternLM2Patcher(self, model, model_kwargs=model_kwargs)
 
 
 @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
@@ -501,6 +504,12 @@ def patch_model_for_export(
     library_name="transformers",
 )
 class Phi3OpenVINOConfig(PhiOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        MistralDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
+
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
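Aside (editor's note, not part of the diff): swapping in MistralDummyPastKeyValuesGenerator matters because phi-3 checkpoints can use grouped-query attention, where the KV cache carries num_key_value_heads heads rather than num_attention_heads. A minimal sketch of the resulting past-key-values shape, using made-up config values:

```python
# Illustrative values only; not read from a real phi-3 config.
batch_size, sequence_length = 2, 16
num_attention_heads, num_key_value_heads, hidden_size = 32, 8, 3072

# The Mistral-style generator shapes each per-layer (key, value) tensor as
# (batch, num_key_value_heads, seq_len, head_dim); the default generator
# would put num_attention_heads on the head axis instead.
head_dim = hidden_size // num_attention_heads
pkv_shape = (batch_size, num_key_value_heads, sequence_length, head_dim)
assert pkv_shape == (2, 8, 16, 96)
```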
@@ -608,3 +617,140 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         return {
             "sample": {0: "batch_size", 2: "height", 3: "width"},
         }
+
+
+@register_in_tasks_manager(
+    "persimmon",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager("biogpt", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class BioGPTOpenVINOConfig(TextDecoderOnnxConfig):
+    # BioGPT does not require position_ids input.
+    DEFAULT_ONNX_OPSET = 13
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager(
+    "gpt-neox-japanese", *["text-generation", "text-generation-with-past"], library_name="transformers"
+)
+class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig):
+    # GPTNeoxJapanese does not require position_ids input.
+    DEFAULT_ONNX_OPSET = 13
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager(
+    "cohere",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class CohereOpenVINOConfig(LlamaOpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager("xglm", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 13
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
+        num_attention_heads="attention_heads", hidden_size="d_model"
+    )
+
+
+class AquilaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task,
+            normalized_config,
+            batch_size,
+            sequence_length,
+            random_batch_size_range,
+            random_sequence_length_range,
+            **kwargs,
+        )
+        self.num_key_value_heads = getattr(
+            normalized_config, "num_key_value_heads", normalized_config.num_attention_heads
+        )
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.hidden_size // self.num_attention_heads,
+        )
+        return [
+            (
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
+@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, AquilaDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = AquilaDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return AquilaModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("xverse", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class XverseMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return XverseModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("internlm", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class InternLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)
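Editor's sketch (not in the commit): once these configs are registered, the new model types export through the usual optimum-intel entry points. The checkpoint name and flags below are illustrative assumptions; aquila, xverse, and internlm ship custom modeling code, hence trust_remote_code=True.

```python
# Minimal usage sketch, assuming an "aquila"-type checkpoint is available.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "BAAI/AquilaChat2-7B"  # illustrative checkpoint choice
model = OVModelForCausalLM.from_pretrained(
    model_id,
    export=True,             # convert to OpenVINO IR via the config registered above
    trust_remote_code=True,  # aquila uses remote modeling code
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("OpenVINO export check:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```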