Skip to content

Commit f34bd61

Browse files
jane-intelIlyasMoutawwakileaidovaecharlaix
authored
[OV] Sets symbols on inputs and variables (#696)
* Prepare shape with symbols before conversion * Bumping lower bound of openvino version due to backward incompatible optimum change * style * style Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com> * style * style * Style * rename function * fix test --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com> Co-authored-by: Ella Charlaix <ella@huggingface.co>
1 parent 5b10499 commit f34bd61

File tree

2 files changed

+56
-53
lines changed

2 files changed

+56
-53
lines changed

optimum/exporters/openvino/convert.py

+13-27
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import functools
1616
import gc
17-
import inspect
1817
import logging
1918
import os
2019
from pathlib import Path
@@ -23,9 +22,8 @@
2322
import onnx
2423
from transformers.utils import is_tf_available, is_torch_available
2524

26-
from openvino.runtime import Model, PartialShape, save_model
25+
from openvino.runtime import Model, save_model
2726
from openvino.runtime.exceptions import OVTypeError
28-
from openvino.runtime.utils.types import get_element_type
2927
from openvino.tools.ovc import convert_model
3028
from optimum.exporters import TasksManager
3129
from optimum.exporters.onnx.base import OnnxConfig
@@ -50,9 +48,8 @@
5048
from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
5149
from .utils import (
5250
OV_XML_FILE_NAME,
51+
_get_input_info,
5352
clear_class_registry,
54-
flattenize_inputs,
55-
get_input_shapes,
5653
remove_none_from_dummy_inputs,
5754
)
5855

@@ -374,13 +371,12 @@ def ts_patched_forward(*args, **kwargs):
374371

375372
__make_16bit_traceable(model)
376373
check_dummy_inputs_are_allowed(model, dummy_inputs)
377-
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
378-
inputs = config.ordered_inputs(model)
379-
input_names = list(inputs.keys())
380-
output_names = list(config.outputs.keys())
381-
input_info = get_input_shapes(dummy_inputs, inputs)
382-
383-
ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
374+
input_info = _get_input_info(model, config, dummy_inputs)
375+
ov_model = convert_model(
376+
model,
377+
example_input=dummy_inputs,
378+
input=[(item.shape, item.type) for item in input_info],
379+
)
384380

385381
except Exception as ex:
386382
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
@@ -411,27 +407,17 @@ def ts_patched_forward(*args, **kwargs):
411407
ov_config=ov_config,
412408
)
413409

414-
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
415-
if not ordered_dummy_inputs:
416-
ordered_dummy_inputs = dummy_inputs
417-
ordered_input_names = list(inputs)
418-
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
419-
ov_model.validate_nodes_and_infer_types()
410+
ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?
411+
412+
output_names = list(config.outputs.keys())
420413
for idx, out_tensor in enumerate(ov_model.outputs):
421414
if idx < len(output_names):
422415
out_tensor.get_tensor().set_names({output_names[idx]})
423416

417+
input_names = [item.name for item in input_info]
424418
for idx, inp_tensor in enumerate(ov_model.inputs):
425-
input_name = ordered_input_names[idx]
419+
input_name = input_names[idx]
426420
inp_tensor.get_tensor().set_names({input_name})
427-
inp_data = flatten_inputs[idx]
428-
static_shape = PartialShape(inp_data.shape)
429-
dims = inputs.get(input_name, [])
430-
for dim in dims:
431-
static_shape[dim] = -1
432-
inp_tensor.get_node().set_partial_shape(static_shape)
433-
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
434-
ov_model.validate_nodes_and_infer_types()
435421

436422
if stateful:
437423
patch_stateful(model.config, ov_model)

optimum/exporters/openvino/utils.py

+43-26
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
16+
from collections import namedtuple
1517
from typing import Any, Dict, List, Tuple, Union
1618

1719
from transformers.utils import is_torch_available
1820

19-
from openvino.runtime import PartialShape
21+
from openvino.runtime import Dimension, PartialShape, Symbol
22+
from openvino.runtime.utils.types import get_element_type
23+
from optimum.exporters.onnx.base import OnnxConfig
2024
from optimum.utils import is_diffusers_available
2125

2226

27+
InputInfo = namedtuple("InputInfo", ["name", "shape", "type", "example"])
28+
29+
2330
if is_torch_available():
2431
import torch
2532
import torch.nn as nn
@@ -69,6 +76,41 @@ def flattenize_inputs(inputs: List[Any]):
6976
return flatten_inputs
7077

7178

79+
def _get_input_info(
80+
model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any]
81+
) -> List[InputInfo]:
82+
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
83+
inputs = config.ordered_inputs(model)
84+
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
85+
if not ordered_dummy_inputs:
86+
ordered_dummy_inputs = dummy_inputs
87+
ordered_input_names = list(inputs)
88+
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
89+
input_info = []
90+
91+
name_to_symbol = {}
92+
93+
for i in range(len(ordered_input_names)):
94+
name = ordered_input_names[i]
95+
example = flatten_inputs[i]
96+
type = get_element_type(example.cpu().numpy().dtype)
97+
shape = PartialShape(example.shape)
98+
if name in inputs:
99+
named_dims = inputs[name]
100+
for idx, dim_name in named_dims.items():
101+
if dim_name in name_to_symbol:
102+
symbol = name_to_symbol[dim_name]
103+
else:
104+
symbol = Symbol()
105+
name_to_symbol[name] = symbol
106+
dim = Dimension(-1)
107+
dim.set_symbol(symbol)
108+
shape[idx] = dim
109+
info = InputInfo(name=name, shape=shape, type=type, example=example)
110+
input_info.append(info)
111+
return input_info
112+
113+
72114
def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
73115
"""
74116
Removes None values from the dictionary.
@@ -109,31 +151,6 @@ def remove_none_from_list_tuple(item: Union[List[Any], Tuple[Any]]):
109151
return upd_dummy, dict_dummy
110152

111153

112-
def get_input_shapes(dummy_inputs: Dict[str, Any], inputs: Dict[str, Any]):
113-
"""
114-
Resolves input shapes based on dynamic axes from input config and dummy input shapes
115-
116-
Args:
117-
dummy_inputs (Dict[str, Any]): A dictionary of dummy inputs.
118-
inputs (Dict[str, Any]): A dictionary of input tensors.
119-
120-
Returns:
121-
input_info: List of input info for conversion
122-
123-
"""
124-
input_info = []
125-
for input_name, data in dummy_inputs.items():
126-
if isinstance(data, (tuple, list, dict)):
127-
return None
128-
static_shape = PartialShape(data.shape)
129-
if input_name in inputs:
130-
dynamic_dims = inputs[input_name]
131-
for dim in dynamic_dims:
132-
static_shape[dim] = -1
133-
input_info.append((input_name, static_shape))
134-
return input_info
135-
136-
137154
def clear_class_registry():
138155
"""
139156
Removes Torchscript cached modules

0 commit comments

Comments
 (0)