Commit e57baac

Add openvino export configs and support chatglm
1 parent f52d7c8 commit e57baac

File tree

7 files changed: +337 -27 lines
optimum/exporters/openvino/__init__.py

+4

@@ -1,5 +1,9 @@
 from .__main__ import main_export
+from .base import init_model_configs
 from .convert import export, export_models, export_pytorch_via_onnx
+from .model_configs import *


+init_model_configs()
+
 __all__ = ["main_export", "export", "export_models"]
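Importing optimum.exporters.openvino is now deliberately side-effectful: init_model_configs() runs at import time to mirror every registered ONNX export config under an "openvino" exporter key, and the star-import pulls in the ChatGLM registration from model_configs. A minimal sketch of the observable effect (assuming an optimum install where "bert" is ONNX-exportable; the asserts are illustrative, not from the commit):

    import optimum.exporters.openvino  # triggers init_model_configs() and the chatglm registration
    from optimum.exporters.tasks import TasksManager

    # Every model type with an "onnx" entry now also resolves for "openvino".
    assert "openvino" in TasksManager._SUPPORTED_MODEL_TYPE["bert"]
    # The chatglm entry is added by this commit for the "openvino" exporter only.
    assert "chatglm" in TasksManager._SUPPORTED_MODEL_TYPE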

optimum/exporters/openvino/__main__.py

+120 -23
@@ -15,14 +15,20 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 from requests.exceptions import ConnectionError as RequestsConnectionError
 from transformers import AutoTokenizer

 from optimum.exporters import TasksManager
-from optimum.exporters.onnx import __main__ as optimum_main
 from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
+from optimum.exporters.onnx.utils import (
+    _get_submodels_for_export_encoder_decoder,
+    _get_submodels_for_export_stable_diffusion,
+    get_encoder_decoder_models_for_export,
+    get_sam_models_for_export,
+    get_stable_diffusion_models_for_export,
+)
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
@@ -31,13 +37,113 @@
 from .convert import export_models


+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, TFPreTrainedModel
+
+
 OV_XML_FILE_NAME = "openvino_model.xml"

 _MAX_UNCOMPRESSED_SIZE = 1e9

 logger = logging.getLogger(__name__)


+def _get_submodels_and_export_configs(
+    model: Union["PreTrainedModel", "TFPreTrainedModel"],
+    task: str,
+    custom_onnx_configs: Dict,
+    custom_architecture: bool,
+    _variant: str,
+    int_dtype: str = "int64",
+    float_dtype: str = "fp32",
+    fn_get_submodels: Optional[Callable] = None,
+    preprocessors: Optional[List[Any]] = None,
+    no_position_ids: bool = False,
+):
+    is_stable_diffusion = "stable-diffusion" in task
+    if not custom_architecture:
+        if is_stable_diffusion:
+            onnx_config = None
+            models_and_onnx_configs = get_stable_diffusion_models_for_export(
+                model, int_dtype=int_dtype, float_dtype=float_dtype
+            )
+        else:
+            onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+                model=model, exporter="openvino", task=task
+            )
+            onnx_config_kwargs = {}
+            if task.startswith("text-generation") and no_position_ids:
+                onnx_config_kwargs["no_position_ids"] = no_position_ids
+
+            onnx_config = onnx_config_constructor(
+                model.config,
+                int_dtype=int_dtype,
+                float_dtype=float_dtype,
+                preprocessors=preprocessors,
+                **onnx_config_kwargs,
+            )
+
+            onnx_config.variant = _variant
+            all_variants = "\n".join(
+                [f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
+            )
+            logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")
+
+            if model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
+                models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
+            elif task.startswith("text-generation"):
+                model = patch_decoder_attention_mask(model)
+                onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+                    model=model, exporter="openvino", task=task
+                )
+                onnx_config = onnx_config_constructor(model.config)
+                models_and_onnx_configs = {"model": (model, onnx_config)}
+            elif model.config.model_type == "sam":
+                models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
+            else:
+                models_and_onnx_configs = {"model": (model, onnx_config)}
+
+        # When specifying custom ONNX configs for supported transformers architectures,
+        # we do not require a custom ONNX config for each submodel.
+        for key, custom_onnx_config in custom_onnx_configs.items():
+            models_and_onnx_configs[key] = (models_and_onnx_configs[key][0], custom_onnx_config)
+    else:
+        onnx_config = None
+        submodels_for_export = None
+        models_and_onnx_configs = {}
+
+        if fn_get_submodels is not None:
+            submodels_for_export = fn_get_submodels(model)
+        else:
+            if is_stable_diffusion:
+                submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
+            elif model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS):
+                submodels_for_export = _get_submodels_for_export_encoder_decoder(
+                    model, use_past=task.endswith("-with-past")
+                )
+            elif task.startswith("text-generation"):
+                model = patch_decoder_attention_mask(model)
+                submodels_for_export = {"model": model}
+            else:
+                submodels_for_export = {"model": model}
+
+        if submodels_for_export.keys() != custom_onnx_configs.keys():
+            logger.error(f"ONNX custom configs for: {', '.join(custom_onnx_configs.keys())}")
+            logger.error(f"Submodels to export: {', '.join(submodels_for_export.keys())}")
+            raise ValueError(
+                "Trying to export a custom model, but could not find as many custom ONNX configs as the number of submodels to export. Please specify the fn_get_submodels argument, which should return a dictionary of submodels with as many items as the provided custom_onnx_configs dictionary."
+            )
+
+        for key, custom_onnx_config in custom_onnx_configs.items():
+            models_and_onnx_configs[key] = (submodels_for_export[key], custom_onnx_config)
+
+    # Default to the first ONNX config for the stable-diffusion and custom architecture cases.
+    if onnx_config is None:
+        onnx_config = next(iter(models_and_onnx_configs.values()))[1]
+
+    return onnx_config, models_and_onnx_configs
+
+
 def main_export(
     model_name_or_path: str,
     output: Union[str, Path],
@@ -183,7 +289,7 @@ def main_export(
             f"If you want to support {model_type} please propose a PR or open up an issue."
         )
     if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task(
-        task, exporter="onnx"
+        task, exporter="openvino"
     ):
         custom_architecture = True

@@ -200,7 +306,7 @@ def main_export(
     if (
         not custom_architecture
        and not is_stable_diffusion
-        and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx")
+        and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "openvino")
     ):
         if original_task == "auto":  # Make -with-past the default if --task was not explicitly specified
             task = task + "-with-past"
@@ -222,24 +328,15 @@ def main_export(
     preprocessors = maybe_load_preprocessors(
         model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
     )
-    if not task.startswith("text-generation"):
-        onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
-            model=model,
-            task=task,
-            monolith=False,
-            custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
-            custom_architecture=custom_architecture,
-            fn_get_submodels=fn_get_submodels,
-            preprocessors=preprocessors,
-            _variant="default",
-        )
-    else:
-        # TODO : ModelPatcher will be added in next optimum release
-        model = patch_decoder_attention_mask(model)
-
-        onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-        onnx_config = onnx_config_constructor(model.config)
-        models_and_onnx_configs = {"model": (model, onnx_config)}
+    onnx_config, models_and_onnx_configs = _get_submodels_and_export_configs(
+        model=model,
+        task=task,
+        custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
+        custom_architecture=custom_architecture,
+        fn_get_submodels=fn_get_submodels,
+        preprocessors=preprocessors,
+        _variant="default",
+    )

     if int8 is None:
         int8 = False
@@ -276,7 +373,7 @@ def main_export(
     generation_config = getattr(model, "generation_config", None)
     if generation_config is not None:
         generation_config.save_pretrained(output)
-    maybe_save_preprocessors(model_name_or_path, output)
+    maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)

     if model.config.is_encoder_decoder and task.startswith("text-generation"):
         raise ValueError(
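With these changes, main_export no longer special-cases text-generation models and routes everything through the "openvino" backend registry. A usage sketch for the newly supported ChatGLM family (the checkpoint name and keyword set are assumptions based on the parameters visible in this diff, not a verbatim example from the commit):

    from optimum.exporters.openvino import main_export

    # ChatGLM ships custom modeling code, hence trust_remote_code=True; with
    # task="auto" the export would default to text-generation-with-past.
    main_export(
        model_name_or_path="THUDM/chatglm2-6b",
        output="chatglm2_openvino",
        task="text-generation-with-past",
        trust_remote_code=True,
    )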

optimum/exporters/openvino/base.py

+25
@@ -0,0 +1,25 @@
+from copy import deepcopy
+from typing import Callable, Type
+
+from optimum.exporters.tasks import TasksManager
+from optimum.utils.normalized_config import NormalizedConfigManager
+
+
+def init_model_configs():
+    supported_models = TasksManager._SUPPORTED_MODEL_TYPE
+    for model, export_configs in supported_models.items():
+        if "onnx" not in export_configs:
+            continue
+        TasksManager._SUPPORTED_MODEL_TYPE[model]["openvino"] = deepcopy(
+            TasksManager._SUPPORTED_MODEL_TYPE[model]["onnx"]
+        )
+
+
+def register_normalized_config(model_type: str) -> Callable[[Type], Type]:
+    def decorator(config_cls: Type) -> Type:
+        if model_type in NormalizedConfigManager._conf:
+            return config_cls
+        NormalizedConfigManager._conf[model_type] = config_cls
+        return config_cls
+
+    return decorator
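register_normalized_config intentionally refuses to overwrite an existing entry, so a custom architecture can register its normalized config without clobbering optimum's built-ins. A hypothetical third-party registration (the model type and attribute name are illustrative only):

    from optimum.exporters.openvino.base import register_normalized_config
    from optimum.utils import NormalizedTextConfig

    @register_normalized_config("my-decoder")  # no-op if "my-decoder" is already registered
    class MyDecoderNormalizedConfig(NormalizedTextConfig):
        NUM_LAYERS = "n_layer"  # maps optimum's num_layers onto this config's attribute name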
optimum/exporters/openvino/dummy_input_generators.py

+61

@@ -0,0 +1,61 @@
+from typing import Optional, Tuple
+
+from optimum.utils import (
+    DEFAULT_DUMMY_SHAPES,
+    DummyPastKeyValuesGenerator,
+    DummyTextInputGenerator,
+    NormalizedTextConfig,
+)
+
+
+class ChatGLN2DummyTextInputGenerator(DummyTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = {
+        "input_ids",
+        "attention_mask",
+        "token_type_ids",
+        "position_ids",
+    }
+
+
+class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.multi_query_group_num = normalized_config.multi_query_group_num
+        self.head_dim = self.hidden_size // self.num_attention_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        past_key_shape = (
+            self.sequence_length,
+            self.batch_size,
+            self.multi_query_group_num,
+            self.head_dim,
+        )
+        past_value_shape = (
+            self.sequence_length,
+            self.batch_size,
+            self.multi_query_group_num,
+            self.head_dim,
+        )
+        return [
+            (
+                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
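The ChatGLM2 cache layout differs from the usual [batch, heads, sequence, head_dim]: keys and values use [sequence, batch, multi_query_group, head_dim] because the model uses multi-query attention. A quick shape check (a sketch with made-up ChatGLM2-6B-like hyperparameters, using types.SimpleNamespace as a stand-in for the real model config, which holds only as long as the normalized config performs plain attribute access):

    from types import SimpleNamespace

    from optimum.exporters.openvino.dummy_input_generators import ChatGLM2DummyPastKeyValuesGenerator
    from optimum.exporters.openvino.normalized_configs import ChatGLM2NormalizedConfig

    normalized = ChatGLM2NormalizedConfig(SimpleNamespace(
        num_layers=28, hidden_size=4096, num_attention_heads=32, multi_query_group_num=2
    ))
    generator = ChatGLM2DummyPastKeyValuesGenerator(
        task="text-generation", normalized_config=normalized, batch_size=2, sequence_length=8
    )
    pkv = generator.generate("past_key_values", framework="pt")
    assert len(pkv) == 28                     # one (key, value) pair per layer
    assert pkv[0][0].shape == (8, 2, 2, 128)  # [sequence, batch, group, head_dim], head_dim = 4096 // 32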
optimum/exporters/openvino/model_configs.py

+91

@@ -0,0 +1,91 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Dict, Type
+
+from optimum.exporters.onnx import TextDecoderOnnxConfig
+from optimum.exporters.tasks import TasksManager, make_backend_config_constructor_for_task
+
+from .dummy_input_generators import ChatGLM2DummyPastKeyValuesGenerator, ChatGLN2DummyTextInputGenerator
+from .normalized_configs import ChatGLM2NormalizedConfig
+
+
+def create_register(overwrite_existing: bool = False):
+    def wrapper(model_type: str, *supported_tasks: str) -> Callable[[Type], Type]:
+        def decorator(config_cls: Type) -> Type:
+            mapping = TasksManager._SUPPORTED_MODEL_TYPE.get(model_type, {})
+            mapping_backend = mapping.get("openvino", {})
+            for task in supported_tasks:
+                normalized_task = task
+                if "-with-past" in task:
+                    normalized_task = task.split("-with-past")[0]
+                if normalized_task not in TasksManager.get_all_tasks():
+                    known_tasks = ", ".join(TasksManager.get_all_tasks())
+                    raise ValueError(
+                        f'The TasksManager does not know the task called "{task}", known tasks: {known_tasks}.'
+                    )
+                if not overwrite_existing and task in mapping_backend:
+                    continue
+                mapping_backend[task] = make_backend_config_constructor_for_task(config_cls, task)
+            mapping["openvino"] = mapping_backend
+            TasksManager._SUPPORTED_MODEL_TYPE[model_type] = mapping
+            return config_cls
+
+        return decorator
+
+    return wrapper
+
+
+register_in_tasks_manager = create_register(True)
+
+
+@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
+class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = ChatGLM2NormalizedConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (ChatGLN2DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
+    no_position_ids = False
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = super().inputs
+        common_inputs.pop("attention_mask")
+        if not self.no_position_ids and self.task == "text-generation":
+            common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
+
+        return common_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills `inputs_or_outputs` mapping with past_key_values dynamic axes considering the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
+                The mapping to fill.
+            direction (`str`):
+                either "inputs" or "outputs", it specifies whether `inputs_or_outputs` is the input mapping or the
+                output mapping, this is important for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            decoder_sequence_name = "past_sequence_length"
+            name = "past_key_values"
+        else:
+            decoder_sequence_name = "past_sequence_length + 1"
+            name = "present"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}
optimum/exporters/openvino/normalized_configs.py

+9

@@ -0,0 +1,9 @@
+from optimum.utils import NormalizedTextConfig
+
+from .base import register_normalized_config
+
+
+@register_normalized_config("chatglm")
+class ChatGLM2NormalizedConfig(NormalizedTextConfig):
+    NUM_LAYERS = "num_layers"
+    VOCAB_SIZE = "padded_vocab_size"
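NormalizedTextConfig subclasses only remap attribute names: ChatGLM stores its depth as num_layers and its vocabulary size as padded_vocab_size, and this class re-exposes both under optimum's canonical names. A minimal sketch (SimpleNamespace again stands in for the real ChatGLM config, with made-up values):

    from types import SimpleNamespace

    from optimum.exporters.openvino.normalized_configs import ChatGLM2NormalizedConfig

    cfg = ChatGLM2NormalizedConfig(SimpleNamespace(num_layers=28, padded_vocab_size=65024))
    assert cfg.num_layers == 28
    assert cfg.vocab_size == 65024  # resolved through VOCAB_SIZE = "padded_vocab_size"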
