Skip to content

Commit fd19842

Browse files
committed
add orion
1 parent 59c3967 commit fd19842

File tree

4 files changed

+16
-2
lines changed

4 files changed

+16
-2
lines changed

optimum/exporters/openvino/__main__.py

+2
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ def main_export(
203203
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
204204
model_type = config.model_type.replace("_", "-")
205205

206+
if model_type in {"falcon", "mpt"} and trust_remote_code:
207+
trust_remote_code = False
206208
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
207209
custom_architecture = True
208210
elif task not in TasksManager.get_supported_tasks_for_model_type(

optimum/exporters/openvino/model_configs.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def init_model_configs():
7575

7676

7777
@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
78-
class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
78+
class BaichaunOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
7979
DEFAULT_ONNX_OPSET = 13
8080
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
8181
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
@@ -471,3 +471,12 @@ class DeciOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
471471
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DeciDummyPastKeyValuesGenerator)
472472
DUMMY_PKV_GENERATOR_CLASS = DeciDummyPastKeyValuesGenerator
473473
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
474+
475+
476+
@register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
477+
class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
478+
DEFAULT_ONNX_OPSET = 14
479+
480+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
481+
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
482+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

tests/openvino/test_modeling.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -521,10 +521,11 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
521521
"stablelm",
522522
"starcoder2",
523523
"phi",
524+
"internlm2",
524525
)
525526
GENERATION_LENGTH = 100
526527
IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
527-
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen")
528+
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "decilm")
528529

529530
@parameterized.expand(SUPPORTED_ARCHITECTURES)
530531
def test_compare_to_transformers(self, model_arch):

tests/openvino/utils_tests.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
3737
"deberta": "hf-internal-testing/tiny-random-deberta",
3838
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
39+
"decilm": "katuni4ka/tiny-random-deciml",
3940
"deit": "hf-internal-testing/tiny-random-deit",
4041
"convnext": "hf-internal-testing/tiny-random-convnext",
4142
"distilbert": "hf-internal-testing/tiny-random-distilbert",
@@ -49,6 +50,7 @@
4950
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
5051
"hubert": "hf-internal-testing/tiny-random-HubertModel",
5152
"ibert": "hf-internal-testing/tiny-random-ibert",
53+
"internlm2": "katuni4ka/tiny-random-internlm2",
5254
"levit": "hf-internal-testing/tiny-random-LevitModel",
5355
"longt5": "hf-internal-testing/tiny-random-longt5",
5456
"llama": "fxmarty/tiny-llama-fast-tokenizer",

0 commit comments

Comments
 (0)