Skip to content

Commit 7114cb2

Browse files
committed
add check for with past
1 parent 2b186f2 commit 7114cb2

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

optimum/exporters/openvino/convert.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,12 @@ def export_pytorch(
293293
logger.info(f"Using framework PyTorch: {torch.__version__}")
294294
output = Path(output)
295295

296-
if ensure_export_task_support_stateful(config.task):
297-
# Trigger bettertransformer together with stateful model because OpenVINO HW-dependent transformations expect
298-
# both of them are applied to demonstrate the best performance.
299-
# TODO: Consider applying bettertransformer regardless of stateful flag -- requires additional validation.
296+
task = config.task
297+
if getattr(config, "use_past", False):
298+
task += "-with-past"
299+
if ensure_export_task_support_stateful(task):
300+
# Trigger bettertransformer together with text-generation-with-past models because OpenVINO HW-dependent transformations expect
301+
# SDPA of them are applied to demonstrate the best performance.
300302
model = patch_model_with_bettertransformer(model)
301303
# TODO: Consider unpatching model after export is done in the end of this function.
302304
# Now it is left as-is because the model is not expected to be used after call export_pytorch, and

0 commit comments

Comments
 (0)