Skip to content

Commit 4ce6472

Browse files
committed
add orion
1 parent 3b3dd9f commit 4ce6472

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

optimum/exporters/openvino/__main__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def main_export(
5050
device: str = "cpu",
5151
framework: Optional[str] = None,
5252
cache_dir: Optional[str] = None,
53-
trust_remote_code: bool = False,
53+
trust_remote_code: bool = None,
5454
pad_token_id: Optional[int] = None,
5555
subfolder: str = "",
5656
revision: str = "main",
@@ -203,6 +203,8 @@ def main_export(
203203
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
204204
model_type = config.model_type.replace("_", "-")
205205

206+
if model_type in {"falcon", "mpt"} and trust_remote_code:
207+
trust_remote_code = False
206208
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
207209
custom_architecture = True
208210
elif task not in TasksManager.get_supported_tasks_for_model_type(

optimum/exporters/openvino/model_configs.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def init_model_configs():
7575

7676

7777
@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
78-
class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
78+
class BaichaunOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
7979
DEFAULT_ONNX_OPSET = 13
8080
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
8181
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
@@ -471,3 +471,12 @@ class DeciOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
471471
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DeciDummyPastKeyValuesGenerator)
472472
DUMMY_PKV_GENERATOR_CLASS = DeciDummyPastKeyValuesGenerator
473473
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
474+
475+
476+
@register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
477+
class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
478+
DEFAULT_ONNX_OPSET = 14
479+
480+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
481+
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
482+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

0 commit comments

Comments
 (0)