Skip to content

Commit 4d4f137

Browse files
committed
restore input format for stable diffusion
1 parent ea6fa42 commit 4d4f137

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

optimum/exporters/openvino/model_configs.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,14 @@ def patch_model_for_export(
12561256
) -> ModelPatcher:
12571257
return ModelPatcher(self, model, model_kwargs=model_kwargs)
12581258

1259+
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
1260+
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
1261+
# TODO: fix should be by casting inputs during inference and not export
1262+
if framework == "pt":
1263+
import torch
1264+
dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
1265+
return dummy_inputs
1266+
12591267

12601268
@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="transformers")
12611269
@register_in_tasks_manager("clip-text-with-projection", *["feature-extraction"], library_name="diffusers")
@@ -1795,9 +1803,17 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
17951803
)
17961804

17971805

1806+
class DummyUnetTimestepInputGenerator(DummyTimestepInputGenerator):
1807+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1808+
if input_name != "timestep":
1809+
return super().generate(input_name, framework, int_dtype, float_dtype)
1810+
shape = [self.batch_size]
1811+
return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)
1812+
1813+
17981814
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
17991815
class UnetOpenVINOConfig(UNetOnnxConfig):
1800-
DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator,) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:]
1816+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator, DummyUnetTimestepInputGenerator) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:]
18011817

18021818

18031819
@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")

0 commit comments

Comments
 (0)