Skip to content

Commit 9d6d4bb

Browse files
committed
make input dynamic and enable sdpa
1 parent 0249b17 commit 9d6d4bb

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

optimum/exporters/openvino/__main__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,11 @@ def main_export(
269269
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
270270
)
271271

272-
if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
272+
if (
273+
is_transformers_version(">=", "4.36")
274+
and is_transformers_version("<=", "4.45.0")
275+
and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
276+
):
273277
loading_kwargs["attn_implementation"] = "eager"
274278

275279
# some models force flash_attn attention by default that does not support load model on cpu

optimum/exporters/openvino/model_configs.py

+14
Original file line numberDiff line numberDiff line change
@@ -2285,6 +2285,13 @@ def patch_model_for_export(
22852285
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
22862286
return super().patch_model_for_export(model, model_kwargs)
22872287

2288+
@property
2289+
def inputs(self):
2290+
common_inputs = super().inputs
2291+
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
2292+
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
2293+
return common_inputs
2294+
22882295

22892296
@register_in_tasks_manager(
22902297
"t5",
@@ -2299,6 +2306,13 @@ def patch_model_for_export(
22992306
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
23002307
return super().patch_model_for_export(model, model_kwargs)
23012308

2309+
@property
2310+
def inputs(self):
2311+
common_inputs = super().inputs
2312+
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
2313+
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
2314+
return common_inputs
2315+
23022316

23032317
@register_in_tasks_manager(
23042318
"mt5",

0 commit comments

Comments
 (0)