Skip to content

Commit 6b9dc88

Browse files
committed
make input dynamic and enable sdpa
1 parent 0249b17 commit 6b9dc88

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

optimum/exporters/openvino/__main__.py

+5-1
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,11 @@ def main_export(
269269
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
270270
)
271271

272-
if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
272+
if (
273+
is_transformers_version(">=", "4.36")
274+
and is_transformers_version("<=", "4.45.0")
275+
and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
276+
):
273277
loading_kwargs["attn_implementation"] = "eager"
274278

275279
# some models force flash_attn attention by default that does not support load model on cpu

optimum/exporters/openvino/convert.py

+4-5
Original file line number | Diff line number | Diff line change
@@ -28,14 +28,14 @@
2828
from openvino.tools.ovc import convert_model
2929
from optimum.exporters import TasksManager
3030
from optimum.exporters.utils import (
31-
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
32-
)
33-
from optimum.exporters.utils import (
34-
get_diffusion_models_for_export,
3531
DECODER_NAME,
3632
DECODER_WITH_PAST_NAME,
3733
ENCODER_NAME,
3834
_get_submodels_for_export_encoder_decoder,
35+
get_diffusion_models_for_export,
36+
)
37+
from optimum.exporters.utils import (
38+
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
3939
)
4040
from optimum.intel.utils.import_utils import (
4141
_diffusers_version,
@@ -47,7 +47,6 @@
4747
_torch_version,
4848
_transformers_version,
4949
compare_versions,
50-
is_openvino_version,
5150
is_openvino_tokenizers_version,
5251
is_tokenizers_version,
5352
is_transformers_version,

optimum/exporters/openvino/model_configs.py

+14
Original file line number | Diff line number | Diff line change
@@ -2285,6 +2285,13 @@ def patch_model_for_export(
22852285
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
22862286
return super().patch_model_for_export(model, model_kwargs)
22872287

2288+
@property
2289+
def inputs(self):
2290+
common_inputs = super().inputs
2291+
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
2292+
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
2293+
return common_inputs
2294+
22882295

22892296
@register_in_tasks_manager(
22902297
"t5",
@@ -2299,6 +2306,13 @@ def patch_model_for_export(
22992306
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
23002307
return super().patch_model_for_export(model, model_kwargs)
23012308

2309+
@property
2310+
def inputs(self):
2311+
common_inputs = super().inputs
2312+
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
2313+
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
2314+
return common_inputs
2315+
23022316

23032317
@register_in_tasks_manager(
23042318
"mt5",

optimum/exporters/openvino/model_patcher.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,8 @@
2727
from optimum.exporters.onnx.model_patcher import (
2828
DecoderModelPatcher,
2929
ModelPatcher,
30-
override_arguments,
3130
Seq2SeqModelPatcher,
31+
override_arguments,
3232
)
3333
from optimum.intel.utils.import_utils import (
3434
_openvino_version,

optimum/exporters/openvino/stateful.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,7 @@ def get_shape_of_ops(model: ov.Model):
242242

243243
def get_consumer_nodes(node):
244244
consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()])
245-
return set(input.get_node() for input in consumer_inputs)
245+
return {input.get_node() for input in consumer_inputs}
246246

247247

248248
def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):

0 commit comments

Comments
 (0)