|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import inspect |
| 16 | +from collections import namedtuple |
15 | 17 | from typing import Any, Dict, List, Tuple, Union
|
16 | 18 |
|
17 | 19 | from transformers.utils import is_torch_available
|
18 | 20 |
|
19 |
| -from openvino.runtime import PartialShape |
| 21 | +from openvino.runtime import Dimension, PartialShape, Symbol |
| 22 | +from openvino.runtime.utils.types import get_element_type |
| 23 | +from optimum.exporters.onnx.base import OnnxConfig |
20 | 24 | from optimum.utils import is_diffusers_available
|
21 | 25 |
|
22 | 26 |
|
| 27 | +InputInfo = namedtuple("InputInfo", ["name", "shape", "type", "example"]) |
| 28 | + |
| 29 | + |
23 | 30 | if is_torch_available():
|
24 | 31 | import torch
|
25 | 32 | import torch.nn as nn
|
@@ -69,6 +76,41 @@ def flattenize_inputs(inputs: List[Any]):
|
69 | 76 | return flatten_inputs
|
70 | 77 |
|
71 | 78 |
|
| 79 | +def _get_input_info( |
| 80 | + model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any] |
| 81 | +) -> List[InputInfo]: |
| 82 | + sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call) |
| 83 | + inputs = config.ordered_inputs(model) |
| 84 | + ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs} |
| 85 | + if not ordered_dummy_inputs: |
| 86 | + ordered_dummy_inputs = dummy_inputs |
| 87 | + ordered_input_names = list(inputs) |
| 88 | + flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values()) |
| 89 | + input_info = [] |
| 90 | + |
| 91 | + name_to_symbol = {} |
| 92 | + |
| 93 | + for i in range(len(ordered_input_names)): |
| 94 | + name = ordered_input_names[i] |
| 95 | + example = flatten_inputs[i] |
| 96 | + type = get_element_type(example.cpu().numpy().dtype) |
| 97 | + shape = PartialShape(example.shape) |
| 98 | + if name in inputs: |
| 99 | + named_dims = inputs[name] |
| 100 | + for idx, dim_name in named_dims.items(): |
| 101 | + if dim_name in name_to_symbol: |
| 102 | + symbol = name_to_symbol[dim_name] |
| 103 | + else: |
| 104 | + symbol = Symbol() |
| 105 | + name_to_symbol[name] = symbol |
| 106 | + dim = Dimension(-1) |
| 107 | + dim.set_symbol(symbol) |
| 108 | + shape[idx] = dim |
| 109 | + info = InputInfo(name=name, shape=shape, type=type, example=example) |
| 110 | + input_info.append(info) |
| 111 | + return input_info |
| 112 | + |
| 113 | + |
72 | 114 | def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
|
73 | 115 | """
|
74 | 116 | Removes None values from the dictionary.
|
@@ -109,31 +151,6 @@ def remove_none_from_list_tuple(item: Union[List[Any], Tuple[Any]]):
|
109 | 151 | return upd_dummy, dict_dummy
|
110 | 152 |
|
111 | 153 |
|
112 |
| -def get_input_shapes(dummy_inputs: Dict[str, Any], inputs: Dict[str, Any]): |
113 |
| - """ |
114 |
| - Resolves input shapes based on dynamic axes from input config and dummy input shapes |
115 |
| -
|
116 |
| - Args: |
117 |
| - dummy_inputs (Dict[str, Any]): A dictionary of dummy inputs. |
118 |
| - inputs (Dict[str, Any]): A dictionary of input tensors. |
119 |
| -
|
120 |
| - Returns: |
121 |
| - input_info: List of input info for conversion |
122 |
| -
|
123 |
| - """ |
124 |
| - input_info = [] |
125 |
| - for input_name, data in dummy_inputs.items(): |
126 |
| - if isinstance(data, (tuple, list, dict)): |
127 |
| - return None |
128 |
| - static_shape = PartialShape(data.shape) |
129 |
| - if input_name in inputs: |
130 |
| - dynamic_dims = inputs[input_name] |
131 |
| - for dim in dynamic_dims: |
132 |
| - static_shape[dim] = -1 |
133 |
| - input_info.append((input_name, static_shape)) |
134 |
| - return input_info |
135 |
| - |
136 |
| - |
137 | 154 | def clear_class_registry():
|
138 | 155 | """
|
139 | 156 | Removes Torchscript cached modules
|
|
0 commit comments