Skip to content

Commit e425d76

Browse files
committed
fix tests
1 parent 094f7c2 commit e425d76

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

optimum/intel/openvino/modeling_diffusion.py

+2
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,8 @@ def _reshape_transformer(
642642
batch_size = -1
643643
else:
644644
batch_size *= num_images_per_prompt
645+
# The factor of 2 comes from the guidance scale > 1
646+
batch_size *= 2
645647

646648
height = height // self.vae_scale_factor if height > 0 else height
647649
width = width // self.vae_scale_factor if width > 0 else width

tests/openvino/test_diffusion.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
135135
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
136136
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
137137

138-
np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
138+
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
139139

140-
@parameterized.expand(SUPPORTED_ARCHITECTURES)
140+
@parameterized.expand(["stable-diffusion", "stable-diffusion-xl", "latent-consistency"])
141141
@require_diffusers
142142
def test_callback(self, model_arch: str):
143143
height, width, batch_size = 64, 128, 1
@@ -184,9 +184,10 @@ def test_shape(self, model_arch: str):
184184
elif output_type == "pt":
185185
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
186186
else:
187+
out_channels = pipeline.unet.config.out_channels if pipeline.unet is not None else pipeline.transformer.config.out_channels
187188
self.assertEqual(
188189
outputs.shape,
189-
(batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
190+
(batch_size, out_channels, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
190191
)
191192

192193
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -229,6 +230,22 @@ def test_negative_prompt(self, model_arch: str):
229230
do_classifier_free_guidance=True,
230231
negative_prompt=negative_prompt,
231232
)
233+
elif model_arch == "stable-diffusion-3":
234+
(
235+
inputs["prompt_embeds"],
236+
inputs["negative_prompt_embeds"],
237+
inputs["pooled_prompt_embeds"],
238+
inputs["negative_pooled_prompt_embeds"],
239+
) = pipeline.encode_prompt(
240+
prompt=prompt,
241+
prompt_2=None,
242+
prompt_3=None,
243+
num_images_per_prompt=1,
244+
device=torch.device("cpu"),
245+
do_classifier_free_guidance=True,
246+
negative_prompt=negative_prompt,
247+
)
248+
232249
else:
233250
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = pipeline.encode_prompt(
234251
prompt=prompt,
@@ -288,11 +305,12 @@ def test_height_width_properties(self, model_arch: str):
288305
)
289306

290307
self.assertFalse(ov_pipeline.is_dynamic)
308+
expected_batch = batch_size * num_images_per_prompt
309+
if ov_pipeline.unet is not None and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}:
310+
expected_batch *= 2
291311
self.assertEqual(
292312
ov_pipeline.batch_size,
293-
batch_size
294-
* num_images_per_prompt
295-
* (2 if "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} else 1),
313+
expected_batch,
296314
)
297315
self.assertEqual(ov_pipeline.height, height)
298316
self.assertEqual(ov_pipeline.width, width)
@@ -369,7 +387,7 @@ def test_num_images_per_prompt(self, model_arch: str):
369387
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
370388
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
371389

372-
@parameterized.expand(SUPPORTED_ARCHITECTURES)
390+
@parameterized.expand(["stable-diffusion", "stable-diffusion-xl", "latent-consistency"])
373391
@require_diffusers
374392
def test_callback(self, model_arch: str):
375393
height, width, batch_size = 32, 64, 1
@@ -416,9 +434,10 @@ def test_shape(self, model_arch: str):
416434
elif output_type == "pt":
417435
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
418436
else:
437+
out_channels = pipeline.unet.config.out_channels if pipeline.unet is not None else pipeline.transformer.config.out_channels
419438
self.assertEqual(
420439
outputs.shape,
421-
(batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
440+
(batch_size, out_channels, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
422441
)
423442

424443
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -500,11 +519,12 @@ def test_height_width_properties(self, model_arch: str):
500519
)
501520

502521
self.assertFalse(ov_pipeline.is_dynamic)
522+
expected_batch = batch_size * num_images_per_prompt
523+
if ov_pipeline.unet is not None and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}:
524+
expected_batch *= 2
503525
self.assertEqual(
504526
ov_pipeline.batch_size,
505-
batch_size
506-
* num_images_per_prompt
507-
* (2 if "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} else 1),
527+
expected_batch
508528
)
509529
self.assertEqual(ov_pipeline.height, height)
510530
self.assertEqual(ov_pipeline.width, width)
@@ -586,7 +606,7 @@ def test_num_images_per_prompt(self, model_arch: str):
586606
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
587607
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
588608

589-
@parameterized.expand(SUPPORTED_ARCHITECTURES)
609+
@parameterized.expand(["stable-diffusion", "stable-diffusion-xl"])
590610
@require_diffusers
591611
def test_callback(self, model_arch: str):
592612
height, width, batch_size = 32, 64, 1
@@ -633,9 +653,10 @@ def test_shape(self, model_arch: str):
633653
elif output_type == "pt":
634654
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
635655
else:
656+
out_channels = pipeline.unet.config.out_channels if pipeline.unet is not None else pipeline.transformer.config.out_channels
636657
self.assertEqual(
637658
outputs.shape,
638-
(batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
659+
(batch_size, out_channels, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
639660
)
640661

641662
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -717,11 +738,12 @@ def test_height_width_properties(self, model_arch: str):
717738
)
718739

719740
self.assertFalse(ov_pipeline.is_dynamic)
741+
expected_batch = batch_size * num_images_per_prompt
742+
if ov_pipeline.unet is not None and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}:
743+
expected_batch *= 2
720744
self.assertEqual(
721745
ov_pipeline.batch_size,
722-
batch_size
723-
* num_images_per_prompt
724-
* (2 if "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} else 1),
746+
expected_batch,
725747
)
726748
self.assertEqual(ov_pipeline.height, height)
727749
self.assertEqual(ov_pipeline.width, width)

0 commit comments

Comments
 (0)