24
24
from optimum .exporters .onnx .model_configs import (
25
25
CLIPOnnxConfig ,
26
26
CLIPTextOnnxConfig ,
27
+ CLIPTextWithProjectionOnnxConfig ,
28
+ CLIPVisionModelOnnxConfig ,
27
29
CodeGenOnnxConfig ,
28
30
FalconOnnxConfig ,
29
31
GemmaOnnxConfig ,
35
37
PhiOnnxConfig ,
36
38
VisionOnnxConfig ,
37
39
)
40
+ from optimum .exporters .onnx .model_patcher import ModelPatcher
38
41
from optimum .exporters .tasks import TasksManager
39
42
from optimum .utils import DEFAULT_DUMMY_SHAPES
40
43
from optimum .utils .input_generators import (
@@ -1079,6 +1082,11 @@ def generate_dummy_inputs_for_validation(
1079
1082
reference_model_inputs ["text" ] = reference_model_inputs .pop ("input_ids" )
1080
1083
return super ().generate_dummy_inputs_for_validation (reference_model_inputs )
1081
1084
1085
+ def patch_model_for_export (
1086
+ self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1087
+ ) -> ModelPatcher :
1088
+ return ModelPatcher (self , model , model_kwargs = model_kwargs )
1089
+
1082
1090
1083
1091
@register_in_tasks_manager ("clip-text-model" , * ["feature-extraction" ], library_name = "open_clip" )
1084
1092
class OpenCLIPTextOpenVINOConfig (CLIPTextOnnxConfig ):
@@ -1109,6 +1117,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
1109
1117
dummy_inputs = super ().generate_dummy_inputs (framework = framework , ** kwargs )
1110
1118
return dummy_inputs
1111
1119
1120
+ def patch_model_for_export (
1121
+ self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1122
+ ) -> ModelPatcher :
1123
+ return ModelPatcher (self , model , model_kwargs = model_kwargs )
1124
+
1112
1125
1113
1126
@register_in_tasks_manager ("clip-vision-model" , * ["feature-extraction" ], library_name = "open_clip" )
1114
1127
class OpenCLIPVisualOpenVINOConfig (VisionOnnxConfig ):
@@ -1134,6 +1147,42 @@ def rename_ambiguous_inputs(self, inputs):
1134
1147
return model_inputs
1135
1148
1136
1149
1150
+ @register_in_tasks_manager (
1151
+ "clip" , * ["feature-extraction" , "zero-shot-image-classification" ], library_name = "transformers"
1152
+ )
1153
+ class CLIPOpenVINOConfig (CLIPOnnxConfig ):
1154
+ def patch_model_for_export (
1155
+ self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1156
+ ) -> ModelPatcher :
1157
+ return ModelPatcher (self , model , model_kwargs = model_kwargs )
1158
+
1159
+
1160
+ @register_in_tasks_manager ("clip-text-model" , * ["feature-extraction" ], library_name = "transformers" )
1161
+ @register_in_tasks_manager ("clip-text-model" , * ["feature-extraction" ], library_name = "diffusers" )
1162
+ class CLIPTextOpenVINOConfig (CLIPTextOnnxConfig ):
1163
+ def patch_model_for_export (
1164
+ self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1165
+ ) -> ModelPatcher :
1166
+ return ModelPatcher (self , model , model_kwargs = model_kwargs )
1167
+
1168
+
1169
+ @register_in_tasks_manager ("clip-text-with-projection" , * ["feature-extraction" ], library_name = "transformers" )
1170
+ @register_in_tasks_manager ("clip-text-with-projection" , * ["feature-extraction" ], library_name = "diffusers" )
1171
+ class CLIPTextWithProjectionOpenVINOConfig (CLIPTextWithProjectionOnnxConfig ):
1172
+ def patch_model_for_export (
1173
+ self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1174
+ ) -> ModelPatcher :
1175
+ return ModelPatcher (self , model , model_kwargs = model_kwargs )
1176
+
1177
+
1178
+ @register_in_tasks_manager ("clip-vision-model" , * ["feature-extraction" ], library_name = "transformers" )
1179
+ class CLIPVisionModelOpenVINOConfig (CLIPVisionModelOnnxConfig ):
1180
+ def patch_model_for_export (
1181
+ self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
1182
+ ) -> ModelPatcher :
1183
+ return ModelPatcher (self , model , model_kwargs = model_kwargs )
1184
+
1185
+
1137
1186
@register_in_tasks_manager (
1138
1187
"ibert" ,
1139
1188
* [
0 commit comments