Skip to content

Commit 9c94e92

Browse files
committed
cover more models with openvino export
1 parent d021798 commit 9c94e92

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

optimum/exporters/openvino/model_configs.py

+47
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,50 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
592592
return {
593593
"sample": {0: "batch_size", 2: "height", 3: "width"},
594594
}
595+
596+
597+
@register_in_tasks_manager(
598+
"persimmon",
599+
*[
600+
"feature-extraction",
601+
"feature-extraction-with-past",
602+
"text-generation",
603+
"text-generation-with-past",
604+
"text-classification",
605+
],
606+
library_name="transformers",
607+
)
608+
class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
609+
DEFAULT_ONNX_OPSET = 14
610+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
611+
612+
613+
@register_in_tasks_manager("biogpt", *["text-generation", "text-generation-with-past"], library_name="transformers")
614+
class BioGPTOpenVINOConfig(TextDecoderOnnxConfig):
615+
# BioGPT does not require position_ids input.
616+
DEFAULT_ONNX_OPSET = 13
617+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
618+
619+
620+
@register_in_tasks_manager(
621+
"gpt-neox-japanese", *["text-generation", "text-generation-with-past"], library_name="transformers"
622+
)
623+
class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig):
624+
# GPTNeoxJapanese does not require position_ids input.
625+
DEFAULT_ONNX_OPSET = 13
626+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
627+
628+
629+
@register_in_tasks_manager(
630+
"cohere",
631+
*[
632+
"feature-extraction",
633+
"feature-extraction-with-past",
634+
"text-generation",
635+
"text-generation-with-past",
636+
"text-classification",
637+
],
638+
library_name="transformers",
639+
)
640+
class CohereOpenVINOConfig(LlamaOpenVINOConfig):
641+
pass

tests/openvino/test_modeling.py

+4
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,10 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
552552
"orion",
553553
"falcon",
554554
"falcon-40b",
555+
"persimmon",
556+
"biogpt",
557+
"gpt_neox_japanese",
558+
"cohere",
555559
)
556560
GENERATION_LENGTH = 100
557561
REMOTE_CODE_MODELS = (

tests/openvino/utils_tests.py

+4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
"baichuan2": "katuni4ka/tiny-random-baichuan2",
2727
"baichuan2-13b": "katuni4ka/tiny-random-baichuan2-13b",
2828
"bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
29+
"biogpt": "hf-tiny-model-private/tiny-random-BioGptForCausalLM",
2930
"blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
3031
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
3132
"bloom": "hf-internal-testing/tiny-random-BloomModel",
3233
"camembert": "hf-internal-testing/tiny-random-camembert",
3334
"convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
35+
"cohere": "hf-internal-testing/tiny-random-CohereForCausalLM",
3436
"chatglm": "katuni4ka/tiny-random-chatglm2",
3537
"codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
3638
"data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
@@ -51,6 +53,7 @@
5153
"gpt2": "hf-internal-testing/tiny-random-gpt2",
5254
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
5355
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
56+
"gpt_neox_japanese": "hf-internal-testing/tiny-random-GPTNeoXJapaneseForCausalLM",
5457
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
5558
"hubert": "hf-internal-testing/tiny-random-HubertModel",
5659
"ibert": "hf-internal-testing/tiny-random-ibert",
@@ -78,6 +81,7 @@
7881
"olmo": "katuni4ka/tiny-random-olmo-hf",
7982
"orion": "katuni4ka/tiny-random-orion",
8083
"pegasus": "hf-internal-testing/tiny-random-pegasus",
84+
"persimmon": "hf-internal-testing/tiny-random-PersimmonForCausalLM",
8185
"pix2struct": "fxmarty/pix2struct-tiny-random",
8286
"phi": "echarlaix/tiny-random-PhiForCausalLM",
8387
"phi3": "katuni4ka/tiny-random-phi3",

0 commit comments

Comments
 (0)