Skip to content

Commit 60d5bf6

Browse files
authored
Add support export for new architectures (#716)
* support export more models * update aquila to support v1 and v2
1 parent 715c054 commit 60d5bf6

File tree

4 files changed

+378
-5
lines changed

4 files changed

+378
-5
lines changed

optimum/exporters/openvino/model_configs.py

+87-2
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,18 @@
4141
from optimum.utils.normalized_config import NormalizedTextConfig
4242

4343
from .model_patcher import (
44+
AquilaModelPatcher,
4445
BaichuanModelPatcher,
4546
ChatGLMModelPatcher,
4647
GemmaModelPatcher,
47-
InternLMPatcher,
48+
InternLM2Patcher,
49+
InternLMModelPatcher,
4850
LlamaModelPatcher,
4951
MixtralModelPatcher,
5052
MPTModelPatcher,
5153
Phi3ModelPatcher,
5254
QwenModelPatcher,
55+
XverseModelPatcher,
5356
)
5457

5558

@@ -445,7 +448,7 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
445448
def patch_model_for_export(
446449
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
447450
) -> "ModelPatcher":
448-
return InternLMPatcher(self, model, model_kwargs=model_kwargs)
451+
return InternLM2Patcher(self, model, model_kwargs=model_kwargs)
449452

450453

451454
@register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
@@ -653,3 +656,85 @@ class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig):
653656
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
654657
num_attention_heads="attention_heads", hidden_size="d_model"
655658
)
659+
660+
661+
class AquilaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
662+
def __init__(
663+
self,
664+
task: str,
665+
normalized_config: NormalizedTextConfig,
666+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
667+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
668+
random_batch_size_range: Optional[Tuple[int, int]] = None,
669+
random_sequence_length_range: Optional[Tuple[int, int]] = None,
670+
**kwargs,
671+
):
672+
super().__init__(
673+
task,
674+
normalized_config,
675+
batch_size,
676+
sequence_length,
677+
random_batch_size_range,
678+
random_sequence_length_range,
679+
**kwargs,
680+
)
681+
self.num_key_value_heads = getattr(
682+
normalized_config, "num_key_value_heads", normalized_config.num_attention_heads
683+
)
684+
685+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
686+
shape = (
687+
self.batch_size,
688+
self.num_key_value_heads,
689+
self.sequence_length,
690+
self.hidden_size // self.num_attention_heads,
691+
)
692+
return [
693+
(
694+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
695+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
696+
)
697+
for _ in range(self.num_layers)
698+
]
699+
700+
701+
@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers")
702+
class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
703+
DEFAULT_ONNX_OPSET = 14
704+
705+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, AquilaDummyPastKeyValuesGenerator)
706+
DUMMY_PKV_GENERATOR_CLASS = AquilaDummyPastKeyValuesGenerator
707+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
708+
709+
def patch_model_for_export(
710+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
711+
) -> "ModelPatcher":
712+
return AquilaModelPatcher(self, model, model_kwargs=model_kwargs)
713+
714+
715+
@register_in_tasks_manager("xverse", *["text-generation", "text-generation-with-past"], library_name="transformers")
716+
class XverseMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
717+
DEFAULT_ONNX_OPSET = 14
718+
719+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
720+
DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
721+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
722+
723+
def patch_model_for_export(
724+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
725+
) -> "ModelPatcher":
726+
return XverseModelPatcher(self, model, model_kwargs=model_kwargs)
727+
728+
729+
@register_in_tasks_manager("internlm", *["text-generation", "text-generation-with-past"], library_name="transformers")
730+
class InternLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
731+
DEFAULT_ONNX_OPSET = 14
732+
733+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
734+
DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
735+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
736+
737+
def patch_model_for_export(
738+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
739+
) -> "ModelPatcher":
740+
return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)

0 commit comments

Comments
 (0)