
Commit 8880d2e

Add openvino export configs (#568)
* add openvino export configs
* more libs
* more libs
* mixtral and model patcher
* chatglm export
* rework chatglm config
* more testing models
* rework config registration
* add chatglm in tests
* Update tests/openvino/test_modeling.py
* fix style
* gemma
* add test models
* qwen
* fix failed tests
* add comment for gemma
1 parent 2588077 commit 8880d2e

File tree

10 files changed: +933 -58 lines changed


optimum/exporters/openvino/__init__.py

+2
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import optimum.exporters.openvino.model_configs
+
 from .__main__ import main_export
 from .convert import export, export_from_model, export_models, export_pytorch_via_onnx
 from .stateful import ensure_stateful_is_available, patch_stateful
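The new top-level import exists for its side effect: `model_configs` registers OpenVINO-specific export configs (chatglm, mixtral, gemma, qwen, ...) with optimum's `TasksManager`, so the "openvino" exporter lookups below can resolve them. A minimal sketch of that registration pattern, assuming optimum's `TasksManager.create_register` helper; the model type and config body here are illustrative, not the classes added by this commit:

```python
from optimum.exporters.onnx.config import TextDecoderOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils import NormalizedTextConfig

# Decorator factory that files configs under the "openvino" exporter key in
# TasksManager's registry, overriding any "onnx" entry for the same model type.
register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)


@register_in_tasks_manager("my-model-type", "text-generation", "text-generation-with-past")
class MyModelOpenVINOConfig(TextDecoderOnnxConfig):
    # Illustrative defaults; real configs set model-specific dummy-input
    # generators and normalized-config attribute mappings.
    DEFAULT_ONNX_OPSET = 13
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
```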

optimum/exporters/openvino/__main__.py

+9 -9
@@ -58,7 +58,7 @@ def main_export(
     local_files_only: bool = False,
     use_auth_token: Optional[Union[bool, str]] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
-    custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
+    custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None,
     fn_get_submodels: Optional[Callable] = None,
     compression_option: Optional[str] = None,
     compression_ratio: Optional[float] = None,
@@ -112,11 +112,11 @@ def main_export(
             when running `transformers-cli login` (stored in `~/.huggingface`).
         model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
             Experimental usage: keyword arguments to pass to the model during
-            the export. This argument should be used along the `custom_onnx_configs` argument
+            the export. This argument should be used along the `custom_export_configs` argument
             in case, for example, the model inputs/outputs are changed (for example, if
             `model_kwargs={"output_attentions": True}` is passed).
-        custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`):
-            Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model).
+        custom_export_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`):
+            Experimental usage: override the default export config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model).
         fn_get_submodels (`Optional[Callable]`, defaults to `None`):
             Experimental usage: Override the default submodels that are used at the export. This is
             especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
@@ -134,7 +134,7 @@ def main_export(
     ```python
     >>> from optimum.exporters.openvino import main_export

-    >>> main_export("gpt2", output="gpt2_onnx/")
+    >>> main_export("gpt2", output="gpt2_ov/")
     ```
     """

@@ -206,14 +206,14 @@ def main_export(
     if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
         custom_architecture = True
     elif task not in TasksManager.get_supported_tasks_for_model_type(
-        model_type, exporter="onnx", library_name=library_name
+        model_type, exporter="openvino", library_name=library_name
     ):
         if original_task == "auto":
             autodetected_message = " (auto-detected)"
         else:
             autodetected_message = ""
         model_tasks = TasksManager.get_supported_tasks_for_model_type(
-            model_type, exporter="onnx", library_name=library_name
+            model_type, exporter="openvino", library_name=library_name
         )
         raise ValueError(
             f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
@@ -288,7 +288,7 @@ class StoreAttr(object):
         not custom_architecture
         and library_name != "diffusers"
         and task + "-with-past"
-        in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx", library_name=library_name)
+        in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name)
     ):
         # Make -with-past the default if --task was not explicitely specified
         if original_task == "auto":
@@ -319,7 +319,7 @@ class StoreAttr(object):
         ov_config=ov_config,
         stateful=stateful,
         model_kwargs=model_kwargs,
-        custom_onnx_configs=custom_onnx_configs,
+        custom_export_configs=custom_export_configs,
         fn_get_submodels=fn_get_submodels,
         preprocessors=preprocessors,
         device=device,
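For callers of `main_export`, the rename is mechanical: the override hook keeps the same `Dict[str, OnnxConfig]` shape and only changes its name. A minimal usage sketch; the default call is taken verbatim from the updated docstring, while the commented override assumes a hypothetical `my_export_config` object built by the user:

```python
from optimum.exporters.openvino import main_export

# Default export, exactly as in the updated docstring example.
main_export("gpt2", output="gpt2_ov/")

# Hypothetical override of the export config for the "model" submodel;
# before this commit the keyword was spelled `custom_onnx_configs`.
# main_export("gpt2", output="gpt2_ov/", custom_export_configs={"model": my_export_config})
```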

optimum/exporters/openvino/convert.py

+22 -27
@@ -32,10 +32,11 @@
 from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
 from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
+from optimum.exporters.utils import _get_submodels_and_export_configs
 from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available
 from optimum.utils.save_utils import maybe_save_preprocessors

-from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
+from ...intel.utils.import_utils import is_nncf_available
 from .model_patcher import patch_model_with_bettertransformer
 from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
 from .utils import (
@@ -48,13 +49,6 @@
 )


-if is_optimum_version(">=", "1.16.99"):
-    from optimum.exporters.onnx.utils import _get_submodels_and_onnx_configs
-
-else:
-    from optimum.exporters.onnx.__main__ import _get_submodels_and_onnx_configs
-
-
 UNSUPPORTED_TOKENIZER_CLASSES = (T5Tokenizer, T5TokenizerFast)


@@ -418,7 +412,7 @@ def ts_patched_forward(*args, **kwargs):


 def export_models(
-    models_and_onnx_configs: Dict[
+    models_and_export_configs: Dict[
         str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
     ],
     output_dir: Path,
@@ -434,7 +428,7 @@ def export_models(
     Export the models to OpenVINO IR format

     Args:
-        models_and_onnx_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]):
+        models_and_export_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]):
         output_dir (Path): output directory for saving models
         opset (Optional[int], optional, Default to None): ONNX export opset
         output_names (Optional[List[str]], optional, Defaults to None): model output names
@@ -459,20 +453,20 @@ def export_models(

     outputs = []

-    if output_names is not None and len(output_names) != len(models_and_onnx_configs):
+    if output_names is not None and len(output_names) != len(models_and_export_configs):
         raise ValueError(
-            f"Provided custom names {output_names} for the export of {len(models_and_onnx_configs)} models. Please provide the same number of names as models to export."
+            f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export."
         )

-    for i, model_name in enumerate(models_and_onnx_configs.keys()):
-        submodel, sub_onnx_config = models_and_onnx_configs[model_name]
+    for i, model_name in enumerate(models_and_export_configs.keys()):
+        submodel, sub_export_config = models_and_export_configs[model_name]
         output_name = output_names[i] if output_names is not None else Path(model_name + ".xml")
         output_path = output_dir / output_name
         output_path.parent.mkdir(parents=True, exist_ok=True)
         outputs.append(
             export(
                 model=submodel,
-                config=sub_onnx_config,
+                config=sub_export_config,
                 output=output_path,
                 opset=opset,
                 device=device,
@@ -495,7 +489,7 @@ def export_from_model(
     stateful: bool = True,
     opset: Optional[int] = None,
     model_kwargs: Optional[Dict[str, Any]] = None,
-    custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
+    custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None,
     fn_get_submodels: Optional[Callable] = None,
     preprocessors: List = None,
     device: str = "cpu",
@@ -524,14 +518,14 @@ def export_from_model(
             task = TasksManager._infer_task_from_model_or_model_class(model=model)
         except (ValueError, KeyError) as e:
             raise RuntimeError(
-                f"The model task could not be automatically inferred in `onnx_export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
+                f"The model task could not be automatically inferred in `export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
             )

     if (
         not custom_architecture
         and library_name != "diffusers"
         and task + "-with-past"
-        in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx", library_name=library_name)
+        in TasksManager.get_supported_tasks_for_model_type(model_type, "openvino", library_name=library_name)
     ):
         # -with-past is the default.
         task = task + "-with-past"
@@ -541,9 +535,9 @@ def export_from_model(
     stateful = stateful and ensure_export_task_support_stateful(task)

     # TODO: support onnx_config.py in the model repo
-    if custom_architecture and custom_onnx_configs is None:
+    if custom_architecture and custom_export_configs is None:
         raise ValueError(
-            f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export."
+            f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export."
         )

     if task.startswith("text-generation") and model.config.is_encoder_decoder:
@@ -569,18 +563,19 @@ def export_from_model(
             kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
         )

-    onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs(
+    export_config, models_and_export_configs = _get_submodels_and_export_configs(
         model=model,
         task=task,
         monolith=False,
-        custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
+        custom_export_configs=custom_export_configs if custom_export_configs is not None else {},
         custom_architecture=custom_architecture,
         fn_get_submodels=fn_get_submodels,
         preprocessors=preprocessors,
         library_name=library_name,
         model_kwargs=model_kwargs,
         _variant="default",
         legacy=False,
+        exporter="openvino",
     )

     if ov_config is None:
@@ -612,18 +607,18 @@ def export_from_model(
         model_name_or_path = model.config._name_or_path
         maybe_save_preprocessors(model_name_or_path, output)

-        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_onnx_configs.keys()]
+        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]

     else:
         # save the subcomponent configuration
-        for model_name in models_and_onnx_configs:
-            subcomponent = models_and_onnx_configs[model_name][0]
+        for model_name in models_and_export_configs:
+            subcomponent = models_and_export_configs[model_name][0]
             if hasattr(subcomponent, "save_config"):
                 subcomponent.save_config(output / model_name)
             elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"):
                 subcomponent.config.save_pretrained(output / model_name)

-        files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs]
+        files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_export_configs]

         # Saving the additional components needed to perform inference.
         model.scheduler.save_pretrained(output.joinpath("scheduler"))
@@ -643,7 +638,7 @@ def export_from_model(
         model.save_config(output)

     export_models(
-        models_and_onnx_configs=models_and_onnx_configs,
+        models_and_export_configs=models_and_export_configs,
         output_dir=output,
         output_names=files_subpaths,
         input_shapes=input_shapes,
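Since task lookups in both files now go through the "openvino" exporter key, the registry can be inspected directly once the configs have been imported. A minimal sketch using the same `TasksManager` call as the diff; `chatglm` is one of the model types this commit registers, and the exact task list depends on the installed optimum/optimum-intel versions:

```python
import optimum.exporters.openvino.model_configs  # noqa: F401 -- side effect: registers the configs
from optimum.exporters.tasks import TasksManager

# The same lookup the exporter now performs, pointed at the "openvino" registry.
tasks = TasksManager.get_supported_tasks_for_model_type(
    "chatglm", exporter="openvino", library_name="transformers"
)
print(sorted(tasks))  # expected to include text-generation / text-generation-with-past
```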
