Skip to content

Commit 02d5e4e

Browse files
authored
Cover more models with openvino export (#709)
* Cover more models with OpenVINO export
* Add XGLM support
* Fix tests
1 parent 3cfbc38 commit 02d5e4e

File tree

3 files changed

+69
-1
lines changed

3 files changed

+69
-1
lines changed

optimum/exporters/openvino/model_configs.py

+55
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,58 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
592592
return {
593593
"sample": {0: "batch_size", 2: "height", 3: "width"},
594594
}
595+
596+
597+
@register_in_tasks_manager(
    "persimmon",
    "feature-extraction",
    "feature-extraction-with-past",
    "text-generation",
    "text-generation-with-past",
    "text-classification",
    library_name="transformers",
)
class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    """OpenVINO export config for Persimmon decoder models.

    Persimmon takes a position_ids input, hence the
    TextDecoderWithPositionIdsOnnxConfig base class.
    """

    # Opset 14 is the minimum required by this architecture's export graph.
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
611+
612+
613+
@register_in_tasks_manager("biogpt", "text-generation", "text-generation-with-past", library_name="transformers")
class BioGPTOpenVINOConfig(TextDecoderOnnxConfig):
    """OpenVINO export config for BioGPT.

    Uses the plain TextDecoderOnnxConfig base because BioGPT does not
    require a position_ids input.
    """

    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
618+
619+
620+
@register_in_tasks_manager(
    "gpt-neox-japanese", "text-generation", "text-generation-with-past", library_name="transformers"
)
class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig):
    """OpenVINO export config for GPT-NeoX-Japanese.

    Uses the plain TextDecoderOnnxConfig base because GPTNeoxJapanese
    does not require a position_ids input.
    """

    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
627+
628+
629+
@register_in_tasks_manager(
    "cohere",
    "feature-extraction",
    "feature-extraction-with-past",
    "text-generation",
    "text-generation-with-past",
    "text-classification",
    library_name="transformers",
)
class CohereOpenVINOConfig(LlamaOpenVINOConfig):
    """OpenVINO export config for Cohere models.

    Cohere's decoder export is handled identically to Llama, so this
    class only re-registers the Llama config under the "cohere" model type.
    """
642+
643+
644+
@register_in_tasks_manager("xglm", "text-generation", "text-generation-with-past", library_name="transformers")
class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig):
    """OpenVINO export config for XGLM decoder models.

    XGLM's HF config names its attributes differently from the defaults
    expected by NormalizedTextConfig, so they are remapped here.
    """

    # NOTE(review): name deviates from the *OpenVINOConfig convention used by
    # the sibling configs; kept as-is since renaming would change the public API.
    DEFAULT_ONNX_OPSET = 13
    # XGLM exposes `attention_heads` / `d_model` instead of the default
    # `num_attention_heads` / `hidden_size` attribute names.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        num_attention_heads="attention_heads", hidden_size="d_model"
    )

tests/openvino/test_modeling.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,11 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
552552
"orion",
553553
"falcon",
554554
"falcon-40b",
555+
"persimmon",
556+
"biogpt",
557+
"gpt_neox_japanese",
558+
"cohere",
559+
"xglm",
555560
)
556561
GENERATION_LENGTH = 100
557562
REMOTE_CODE_MODELS = (
@@ -617,8 +622,11 @@ def test_compare_to_transformers(self, model_arch):
617622
if model_arch == "qwen":
618623
return
619624

620-
if model_arch != "chatglm":
625+
if model_arch not in ["chatglm", "persimmon"]:
621626
tokenizer.pad_token_id = tokenizer.eos_token_id
627+
628+
if model_arch == "persimmon":
629+
tokenizer.pad_token_id = tokenizer.bos_token_id
622630
# Compare batched generation
623631
tokenizer.padding_side = "left"
624632
tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)

tests/openvino/utils_tests.py

+5
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
"baichuan2": "katuni4ka/tiny-random-baichuan2",
2727
"baichuan2-13b": "katuni4ka/tiny-random-baichuan2-13b",
2828
"bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
29+
"biogpt": "hf-tiny-model-private/tiny-random-BioGptForCausalLM",
2930
"blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
3031
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
3132
"bloom": "hf-internal-testing/tiny-random-BloomModel",
3233
"camembert": "hf-internal-testing/tiny-random-camembert",
3334
"convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
35+
"cohere": "hf-internal-testing/tiny-random-CohereForCausalLM",
3436
"chatglm": "katuni4ka/tiny-random-chatglm2",
3537
"codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
3638
"data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
@@ -51,6 +53,7 @@
5153
"gpt2": "hf-internal-testing/tiny-random-gpt2",
5254
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
5355
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
56+
"gpt_neox_japanese": "hf-internal-testing/tiny-random-GPTNeoXJapaneseForCausalLM",
5457
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
5558
"hubert": "hf-internal-testing/tiny-random-HubertModel",
5659
"ibert": "hf-internal-testing/tiny-random-ibert",
@@ -78,6 +81,7 @@
7881
"olmo": "katuni4ka/tiny-random-olmo-hf",
7982
"orion": "katuni4ka/tiny-random-orion",
8083
"pegasus": "hf-internal-testing/tiny-random-pegasus",
84+
"persimmon": "hf-internal-testing/tiny-random-PersimmonForCausalLM",
8185
"pix2struct": "fxmarty/pix2struct-tiny-random",
8286
"phi": "echarlaix/tiny-random-PhiForCausalLM",
8387
"phi3": "katuni4ka/tiny-random-phi3",
@@ -115,6 +119,7 @@
115119
"whisper": "openai/whisper-tiny.en",
116120
"xlm": "hf-internal-testing/tiny-random-xlm",
117121
"xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
122+
"xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM",
118123
}
119124

120125

0 commit comments

Comments (0)