Skip to content

Commit 84cffb4

Browse files
committed
allow to use SDPA in clip models
1 parent 2559620 commit 84cffb4

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

optimum/exporters/openvino/model_configs.py

+49
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from optimum.exporters.onnx.model_configs import (
2525
CLIPOnnxConfig,
2626
CLIPTextOnnxConfig,
27+
CLIPTextWithProjectionOnnxConfig,
28+
CLIPVisionModelOnnxConfig,
2729
CodeGenOnnxConfig,
2830
FalconOnnxConfig,
2931
GemmaOnnxConfig,
@@ -35,6 +37,7 @@
3537
PhiOnnxConfig,
3638
VisionOnnxConfig,
3739
)
40+
from optimum.exporters.onnx.model_patcher import ModelPatcher
3841
from optimum.exporters.tasks import TasksManager
3942
from optimum.utils import DEFAULT_DUMMY_SHAPES
4043
from optimum.utils.input_generators import (
@@ -1079,6 +1082,11 @@ def generate_dummy_inputs_for_validation(
10791082
reference_model_inputs["text"] = reference_model_inputs.pop("input_ids")
10801083
return super().generate_dummy_inputs_for_validation(reference_model_inputs)
10811084

1085+
def patch_model_for_export(
1086+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
1087+
) -> ModelPatcher:
1088+
return ModelPatcher(self, model, model_kwargs=model_kwargs)
1089+
10821090

10831091
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="open_clip")
10841092
class OpenCLIPTextOpenVINOConfig(CLIPTextOnnxConfig):
@@ -1109,6 +1117,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
11091117
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
11101118
return dummy_inputs
11111119

1120+
def patch_model_for_export(
1121+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
1122+
) -> ModelPatcher:
1123+
return ModelPatcher(self, model, model_kwargs=model_kwargs)
1124+
11121125

11131126
@register_in_tasks_manager("clip-vision-model", *["feature-extraction"], library_name="open_clip")
11141127
class OpenCLIPVisualOpenVINOConfig(VisionOnnxConfig):
@@ -1134,6 +1147,42 @@ def rename_ambiguous_inputs(self, inputs):
11341147
return model_inputs
11351148

11361149

1150+
@register_in_tasks_manager(
1151+
"clip", *["feature-extraction", "zero-shot-image-classification"], library_name="transformers"
1152+
)
1153+
class CLIPOpenVINOConfig(CLIPOnnxConfig):
1154+
def patch_model_for_export(
1155+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
1156+
) -> ModelPatcher:
1157+
return ModelPatcher(self, model, model_kwargs=model_kwargs)
1158+
1159+
1160+
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="transformers")
1161+
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="diffusers")
1162+
class CLIPTextOpenVINOConfig(CLIPTextOnnxConfig):
1163+
def patch_model_for_export(
1164+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
1165+
) -> ModelPatcher:
1166+
return ModelPatcher(self, model, model_kwargs=model_kwargs)
1167+
1168+
1169+
@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="transformers")
1170+
@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="diffusers")
1171+
class CLIPTextWithProjectionOpenVINOConfig(CLIPTextWithProjectionOnnxConfig):
1172+
def patch_model_for_export(
1173+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
1174+
) -> ModelPatcher:
1175+
return ModelPatcher(self, model, model_kwargs=model_kwargs)
1176+
1177+
1178+
@register_in_tasks_manager("clip-vision-model", *["feature-extraction"], library_name="transformers")
1179+
class CLIPVisionModelOpenVINOConfig(CLIPVisionModelOnnxConfig):
1180+
def patch_model_for_export(
1181+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
1182+
) -> ModelPatcher:
1183+
return ModelPatcher(self, model, model_kwargs=model_kwargs)
1184+
1185+
11371186
@register_in_tasks_manager(
11381187
"ibert",
11391188
*[

0 commit comments

Comments
 (0)