Commit b08610f

add openvino export configs

1 parent 8f7d016 · commit b08610f

6 files changed: +107 -24 lines changed
optimum/exporters/openvino/__init__.py (+5)

@@ -1,6 +1,11 @@
 from .__main__ import main_export
+from .base import init_model_configs
 from .convert import export, export_from_model, export_models, export_pytorch_via_onnx
+from .model_configs import *
 from .stateful import ensure_stateful_is_available, patch_stateful
 
 
+init_model_configs()
+
+
 __all__ = ["main_export", "export", "export_models"]

optimum/exporters/openvino/base.py (+27, new file)

@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+from optimum.exporters.tasks import TasksManager
+
+
+def init_model_configs():
+    supported_models = TasksManager._SUPPORTED_MODEL_TYPE
+    for model, export_configs in supported_models.items():
+        if "onnx" not in export_configs:
+            continue
+        TasksManager._SUPPORTED_MODEL_TYPE[model]["openvino"] = deepcopy(
+            TasksManager._SUPPORTED_MODEL_TYPE[model]["onnx"]
+        )
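For context, a minimal sketch (not part of the commit) of what `init_model_configs()` achieves: every model type that already has "onnx" export configs gains a copied "openvino" entry in the `TasksManager` registry, so OpenVINO export falls back to the ONNX configs unless one of the overrides added later in this commit replaces them.

```python
# Hedged sanity-check sketch. Importing the package (or calling
# init_model_configs() directly) populates an "openvino" key for every
# model type that has "onnx" export configs.
from optimum.exporters.openvino import init_model_configs
from optimum.exporters.tasks import TasksManager

init_model_configs()  # also runs as a side effect of the package import

for model_type, configs in TasksManager._SUPPORTED_MODEL_TYPE.items():
    if "onnx" in configs:
        # the copied entry covers the same tasks as the ONNX one
        assert set(configs["openvino"]) == set(configs["onnx"])
```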

optimum/exporters/openvino/convert.py (+16 -21)

@@ -32,10 +32,11 @@
 from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
 from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
+from optimum.exporters.utils import _get_submodels_and_export_configs
 from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available
 from optimum.utils.save_utils import maybe_save_preprocessors
 
-from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
+from ...intel.utils.import_utils import is_nncf_available
 from .model_patcher import patch_model_with_bettertransformer
 from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
 from .utils import (
@@ -48,13 +49,6 @@
 )
 
 
-if is_optimum_version(">=", "1.16.99"):
-    from optimum.exporters.onnx.utils import _get_submodels_and_onnx_configs
-
-else:
-    from optimum.exporters.onnx.__main__ import _get_submodels_and_onnx_configs
-
-
 UNSUPPORTED_TOKENIZER_CLASSES = (T5Tokenizer, T5TokenizerFast)
 
 
@@ -458,7 +452,7 @@ def ts_patched_forward(*args, **kwargs):
 
 
 def export_models(
-    models_and_onnx_configs: Dict[
+    models_and_export_configs: Dict[
         str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
     ],
     output_dir: Path,
@@ -475,7 +469,7 @@
     Export the models to OpenVINO IR format
 
     Args:
-        models_and_onnx_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]):
+        models_and_export_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]):
         output_dir (Path): output directory for saving models
         opset (Optional[int], optional, Default to None): ONNX export opset
         output_names (Optional[List[str]], optional, Defaults to None): model output names
@@ -504,20 +498,20 @@
     # TODO : modify compression_option to quantization_config
     outputs = []
 
-    if output_names is not None and len(output_names) != len(models_and_onnx_configs):
+    if output_names is not None and len(output_names) != len(models_and_export_configs):
         raise ValueError(
-            f"Provided custom names {output_names} for the export of {len(models_and_onnx_configs)} models. Please provide the same number of names as models to export."
+            f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export."
         )
 
-    for i, model_name in enumerate(models_and_onnx_configs.keys()):
-        submodel, sub_onnx_config = models_and_onnx_configs[model_name]
+    for i, model_name in enumerate(models_and_export_configs.keys()):
+        submodel, sub_export_config = models_and_export_configs[model_name]
         output_name = output_names[i] if output_names is not None else Path(model_name + ".xml")
         output_path = output_dir / output_name
         output_path.parent.mkdir(parents=True, exist_ok=True)
         outputs.append(
             export(
                 model=submodel,
-                config=sub_onnx_config,
+                config=sub_export_config,
                 output=output_path,
                 opset=opset,
                 device=device,
@@ -621,7 +615,7 @@ def export_from_model(
             kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
         )
 
-    onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs(
+    export_config, models_and_export_configs = _get_submodels_and_export_configs(
         model=model,
         task=task,
         monolith=False,
@@ -633,6 +627,7 @@ def export_from_model(
         model_kwargs=model_kwargs,
         _variant="default",
         legacy=False,
+        exporter="openvino",
     )
 
     if compression_option is None:
@@ -661,18 +656,18 @@ def export_from_model(
         model_name_or_path = model.config._name_or_path
         maybe_save_preprocessors(model_name_or_path, output)
 
-        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_onnx_configs.keys()]
+        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]
 
     else:
         # save the subcomponent configuration
-        for model_name in models_and_onnx_configs:
-            subcomponent = models_and_onnx_configs[model_name][0]
+        for model_name in models_and_export_configs:
+            subcomponent = models_and_export_configs[model_name][0]
             if hasattr(subcomponent, "save_config"):
                 subcomponent.save_config(output / model_name)
             elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"):
                 subcomponent.config.save_pretrained(output / model_name)
 
-        files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs]
+        files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_export_configs]
 
         # Saving the additional components needed to perform inference.
         model.scheduler.save_pretrained(output.joinpath("scheduler"))
@@ -692,7 +687,7 @@ def export_from_model(
     model.save_config(output)
 
     export_models(
-        models_and_onnx_configs=models_and_onnx_configs,
+        models_and_export_configs=models_and_export_configs,
         output_dir=output,
         output_names=files_subpaths,
         input_shapes=input_shapes,
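For orientation, a hedged usage sketch (not part of the diff): the `export_from_model` path reworked above is driven end to end by the package's public `main_export` entry point. The checkpoint name, output directory, and task below are arbitrary example values.

```python
# Hedged sketch: exporting a checkpoint to OpenVINO IR through the
# public API; "Qwen/Qwen1.5-0.5B" and "qwen2_ov" are example values.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="Qwen/Qwen1.5-0.5B",
    output="qwen2_ov",
    task="text-generation-with-past",
)
```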
optimum/exporters/openvino/model_configs.py (+56, new file)

@@ -0,0 +1,56 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
+from optimum.exporters.tasks import TasksManager
+from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
+from optimum.utils.normalized_config import NormalizedTextConfig
+
+
+register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)
+
+
+@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"])
+class BaichuanOpenVINOConfig(TextDecoderOnnxConfig):
+    DEFAULT_ONNX_OPSET = 13
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
+        num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
+    )
+
+
+@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"])
+class JaisOpenVINOConfig(TextDecoderOnnxConfig):
+    DEFAULT_ONNX_OPSET = 13
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
+        num_layers="n_layer", num_attention_heads="n_head", hidden_size="n_embd"
+    )
+
+
+@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"])
+class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+
+@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"])
+class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
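The decorator returned by `TasksManager.create_register("openvino", overwrite_existing=True)` lets these classes override the mirrored ONNX configs for their model types. A hedged sketch of the same pattern for a hypothetical model type; `my-decoder` and `MyDecoderOpenVINOConfig` are illustrative names, not real registry entries:

```python
# Hypothetical example of the registration pattern used above; the
# model type "my-decoder" and the class name are made up for illustration.
from optimum.exporters.onnx.config import TextDecoderOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils.normalized_config import NormalizedTextConfig

register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)


@register_in_tasks_manager("my-decoder", *["text-generation", "text-generation-with-past"])
class MyDecoderOpenVINOConfig(TextDecoderOnnxConfig):
    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
```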

optimum/intel/openvino/quantization.py (+1 -1)

@@ -335,7 +335,7 @@ def _quantize_torchmodel(
 
         model_type = self.model.config.model_type.replace("_", "-")
         onnx_config_class = TasksManager.get_exporter_config_constructor(
-            exporter="onnx",
+            exporter="openvino",
             model=self.model,
             task=self.task,
             model_type=model_type,
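This switch matters because the overrides in model_configs.py exist only under the "openvino" exporter key. A hedged sketch of the resolved lookup; the model type and task are example values:

```python
# Hedged sketch: resolving an export config constructor through the
# "openvino" key, as _quantize_torchmodel now does; "gpt2" and the
# task below are example values.
from optimum.exporters.tasks import TasksManager

config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="openvino",
    model_type="gpt2",
    task="text-generation",
)
```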

setup.py (+2 -2)

@@ -13,7 +13,7 @@
 
 INSTALL_REQUIRE = [
     "torch>=1.11",
-    "optimum>=1.17.0",
+    "optimum @ git+https://github.com/eaidova/optimum.git@ea/move_model_preparation#egg=optimum",
     "transformers>=4.26.0",
     "datasets>=1.4.0",
     "sentencepiece",
@@ -50,7 +50,7 @@
         "onnx",
         "onnxruntime",
         "transformers>=4.36.0",
-        "optimum>=1.16.1",
+        "optimum @ git+https://github.com/eaidova/optimum.git@ea/move_model_preparation#egg=optimum"
     ],
     "openvino-tokenizers": ["openvino-tokenizers[transformers]"],
     "nncf": ["nncf>=2.8.1"],
