Skip to content

Commit 014a840

Browse files
authored
fix timestep export shapes in sd3 and flux and tests with diffusers 0.32 (#1094)
1 parent 8a56275 commit 014a840

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

optimum/exporters/openvino/model_configs.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -1806,7 +1806,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
18061806

18071807
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
18081808
@register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers")
1809-
class UnetOpenVINOConfig(UNetOnnxConfig):
1809+
class UNetOpenVINOConfig(UNetOnnxConfig):
18101810
DUMMY_INPUT_GENERATOR_CLASSES = (
18111811
DummyUnetVisionInputGenerator,
18121812
DummyUnetTimestepInputGenerator,
@@ -1821,10 +1821,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
18211821

18221822
@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
18231823
@register_in_tasks_manager("sd3-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
1824-
class SD3TransformerOpenVINOConfig(UNetOnnxConfig):
1824+
class SD3TransformerOpenVINOConfig(UNetOpenVINOConfig):
18251825
DUMMY_INPUT_GENERATOR_CLASSES = (
18261826
(DummyTransformerTimestpsInputGenerator,)
1827-
+ UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
1827+
+ UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES
18281828
+ (PooledProjectionsDummyInputGenerator,)
18291829
)
18301830
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(

tests/openvino/test_diffusion.py

+7 -7
Original file line number | Diff line number | Diff line change
@@ -218,8 +218,8 @@ def test_shape(self, model_arch: str):
218218
),
219219
)
220220
else:
221-
packed_height = height // pipeline.vae_scale_factor
222-
packed_width = width // pipeline.vae_scale_factor
221+
packed_height = height // pipeline.vae_scale_factor // 2
222+
packed_width = width // pipeline.vae_scale_factor // 2
223223
channels = pipeline.transformer.config.in_channels
224224
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
225225

@@ -426,7 +426,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
426426
height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
427427
)
428428

429-
if "flux" == model_type:
429+
if model_type in ["flux", "stable-diffusion-3"]:
430430
inputs["height"] = height
431431
inputs["width"] = width
432432

@@ -529,8 +529,8 @@ def test_shape(self, model_arch: str):
529529
),
530530
)
531531
else:
532-
packed_height = height // pipeline.vae_scale_factor
533-
packed_width = width // pipeline.vae_scale_factor
532+
packed_height = height // pipeline.vae_scale_factor // 2
533+
packed_width = width // pipeline.vae_scale_factor // 2
534534
channels = pipeline.transformer.config.in_channels
535535
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
536536

@@ -780,8 +780,8 @@ def test_shape(self, model_arch: str):
780780
),
781781
)
782782
else:
783-
packed_height = height // pipeline.vae_scale_factor
784-
packed_width = width // pipeline.vae_scale_factor
783+
packed_height = height // pipeline.vae_scale_factor // 2
784+
packed_width = width // pipeline.vae_scale_factor // 2
785785
channels = pipeline.transformer.config.in_channels
786786
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
787787

0 commit comments

Comments
 (0)