Skip to content

Commit ca5103d

Browse files
committed
update configs registration
1 parent 4d4f137 commit ca5103d

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

optimum/exporters/openvino/model_configs.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -1250,20 +1250,13 @@ def patch_model_for_export(
12501250

12511251
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="transformers")
12521252
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="diffusers")
1253+
@register_in_tasks_manager("clip-text", *["feature-extraction"], library_name="diffusers")
12531254
class CLIPTextOpenVINOConfig(CLIPTextOnnxConfig):
12541255
def patch_model_for_export(
12551256
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
12561257
) -> ModelPatcher:
12571258
return ModelPatcher(self, model, model_kwargs=model_kwargs)
12581259

1259-
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
1260-
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
1261-
# TODO: fix should be by casting inputs during inference and not export
1262-
if framework == "pt":
1263-
import torch
1264-
dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
1265-
return dummy_inputs
1266-
12671260

12681261
@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="transformers")
12691262
@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="diffusers")
@@ -1812,11 +1805,16 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
18121805

18131806

18141807
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
1808+
@register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
18151809
class UnetOpenVINOConfig(UNetOnnxConfig):
1816-
DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator, DummyUnetTimestepInputGenerator) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:]
1810+
DUMMY_INPUT_GENERATOR_CLASSES = (
1811+
DummyUnetVisionInputGenerator,
1812+
DummyUnetTimestepInputGenerator,
1813+
) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:]
18171814

18181815

18191816
@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
1817+
@register_in_tasks_manager("sd3-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
18201818
class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
18211819
DUMMY_INPUT_GENERATOR_CLASSES = (
18221820
(DummyTransformerTimestpsInputGenerator,)
@@ -1921,6 +1919,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
19211919

19221920

19231921
@register_in_tasks_manager("flux-transformer", *["semantic-segmentation"], library_name="diffusers")
1922+
@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
19241923
class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
19251924
DUMMY_INPUT_GENERATOR_CLASSES = (
19261925
DummyTransformerTimestpsInputGenerator,

0 commit comments

Comments
 (0)