Skip to content

Commit 420fa87

Browse files
authored
support any input resolution in stable diffusion models (#1087)
* support any input resolution in stable diffusion models * Update optimum/exporters/openvino/model_configs.py
1 parent 8ef3997 commit 420fa87

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

optimum/exporters/openvino/model_configs.py

+17
Original file line numberDiff line numberDiff line change
@@ -1783,6 +1783,23 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
17831783
return super().generate(input_name, framework, int_dtype, float_dtype)
17841784

17851785

1786+
class DummyUnetVisionInputGenerator(DummyVisionInputGenerator):
1787+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1788+
if input_name not in ["sample", "latent_sample"]:
1789+
return super().generate(input_name, framework, int_dtype, float_dtype)
1790+
# add height and width discount for enable any resolution generation
1791+
return self.random_float_tensor(
1792+
shape=[self.batch_size, self.num_channels, self.height - 1, self.width - 1],
1793+
framework=framework,
1794+
dtype=float_dtype,
1795+
)
1796+
1797+
1798+
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
1799+
class UnetOpenVINOConfig(UNetOnnxConfig):
1800+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator,) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:]
1801+
1802+
17861803
@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
17871804
class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
17881805
DUMMY_INPUT_GENERATOR_CLASSES = (

tests/openvino/test_diffusion.py

+37
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
144144

145145
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
146146

147+
# test on inputs nondivisible on 64
148+
height, width, batch_size = 96, 96, 1
149+
150+
for output_type in ["latent", "np", "pt"]:
151+
inputs["output_type"] = output_type
152+
153+
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
154+
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
155+
156+
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
157+
147158
@parameterized.expand(CALLBACK_SUPPORT_ARCHITECTURES)
148159
@require_diffusers
149160
def test_callback(self, model_arch: str):
@@ -541,6 +552,20 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
541552

542553
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
543554

555+
# test generation when input resolution nondevisible on 64
556+
height, width, batch_size = 96, 96, 1
557+
558+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
559+
560+
for output_type in ["latent", "np", "pt"]:
561+
print(output_type)
562+
inputs["output_type"] = output_type
563+
564+
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
565+
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
566+
567+
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
568+
544569
@parameterized.expand(SUPPORTED_ARCHITECTURES)
545570
@require_diffusers
546571
def test_image_reproducibility(self, model_arch: str):
@@ -777,6 +802,18 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
777802

778803
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
779804

805+
# test generation when input resolution nondevisible on 64
806+
height, width, batch_size = 96, 96, 1
807+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
808+
809+
for output_type in ["latent", "np", "pt"]:
810+
inputs["output_type"] = output_type
811+
812+
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
813+
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
814+
815+
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
816+
780817
@parameterized.expand(SUPPORTED_ARCHITECTURES)
781818
@require_diffusers
782819
def test_image_reproducibility(self, model_arch: str):

0 commit comments

Comments
 (0)