Skip to content

Commit 87c431c

Browse files
authored
restore input format for stable diffusion and export configs mapping (#1091)
* restore input format for stable diffusion * update configs registration * fix shapes for timestep * align names for t5
1 parent 29b2ac9 commit 87c431c

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

optimum/exporters/openvino/model_configs.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,7 @@ def patch_model_for_export(
12501250

12511251
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="transformers")
12521252
@register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="diffusers")
1253+
@register_in_tasks_manager("clip-text", *["feature-extraction"], library_name="diffusers")
12531254
class CLIPTextOpenVINOConfig(CLIPTextOnnxConfig):
12541255
def patch_model_for_export(
12551256
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
@@ -1795,12 +1796,31 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
17951796
)
17961797

17971798

1799+
class DummyUnetTimestepInputGenerator(DummyTimestepInputGenerator):
1800+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1801+
if input_name != "timestep":
1802+
return super().generate(input_name, framework, int_dtype, float_dtype)
1803+
shape = [self.batch_size]
1804+
return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)
1805+
1806+
17981807
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
1808+
@register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
17991809
class UnetOpenVINOConfig(UNetOnnxConfig):
1800-
DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator,) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:]
1810+
DUMMY_INPUT_GENERATOR_CLASSES = (
1811+
DummyUnetVisionInputGenerator,
1812+
DummyUnetTimestepInputGenerator,
1813+
) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:]
1814+
1815+
@property
1816+
def inputs(self) -> Dict[str, Dict[int, str]]:
1817+
common_inputs = super().inputs
1818+
common_inputs["timestep"] = {0: "batch_size"}
1819+
return common_inputs
18011820

18021821

18031822
@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
1823+
@register_in_tasks_manager("sd3-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
18041824
class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
18051825
DUMMY_INPUT_GENERATOR_CLASSES = (
18061826
(DummyTransformerTimestpsInputGenerator,)
@@ -1830,6 +1850,7 @@ def rename_ambiguous_inputs(self, inputs):
18301850

18311851

18321852
@register_in_tasks_manager("t5-encoder-model", *["feature-extraction"], library_name="diffusers")
1853+
@register_in_tasks_manager("t5-encoder", *["feature-extraction"], library_name="diffusers")
18331854
class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
18341855
pass
18351856

@@ -1905,6 +1926,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
19051926

19061927

19071928
@register_in_tasks_manager("flux-transformer", *["semantic-segmentation"], library_name="diffusers")
1929+
@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
19081930
class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
19091931
DUMMY_INPUT_GENERATOR_CLASSES = (
19101932
DummyTransformerTimestpsInputGenerator,

0 commit comments

Comments
 (0)