Skip to content

Commit 4f1deda

Browse files
committed
fix sd3 tests
1 parent 233246a commit 4f1deda

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

tests/openvino/test_diffusion.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -415,14 +415,16 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
415415

416416
TASK = "image-to-image"
417417

418-
def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"):
418+
def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil", model_type=None):
419419
inputs = _generate_prompts(batch_size=batch_size)
420420

421421
inputs["image"] = _generate_images(
422422
height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
423423
)
424-
inputs["height"] = height
425-
inputs["width"] = width
424+
425+
if "flux" in model_type:
426+
inputs["height"] = height
427+
inputs["width"] = width
426428

427429
inputs["strength"] = 0.75
428430

@@ -452,15 +454,15 @@ def test_num_images_per_prompt(self, model_arch: str):
452454
for height in [64, 128]:
453455
for width in [64, 128]:
454456
for num_images_per_prompt in [1, 3]:
455-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
457+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
456458
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
457459
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
458460

459461
@parameterized.expand(["stable-diffusion", "stable-diffusion-xl", "latent-consistency"])
460462
@require_diffusers
461463
def test_callback(self, model_arch: str):
462464
height, width, batch_size = 32, 64, 1
463-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
465+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
464466

465467
class Callback:
466468
def __init__(self):
@@ -491,7 +493,7 @@ def test_shape(self, model_arch: str):
491493
height, width, batch_size = 128, 64, 1
492494

493495
for input_type in ["pil", "np", "pt"]:
494-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)
496+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type, model_type=model_arch)
495497

496498
for output_type in ["pil", "np", "pt", "latent"]:
497499
inputs["output_type"] = output_type
@@ -528,7 +530,7 @@ def test_shape(self, model_arch: str):
528530
@require_diffusers
529531
def test_compare_to_diffusers_pipeline(self, model_arch: str):
530532
height, width, batch_size = 128, 128, 1
531-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
533+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
532534

533535
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
534536
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
@@ -570,7 +572,7 @@ def test_safety_checker(self, model_arch: str):
570572
self.assertIsInstance(ov_pipeline.safety_checker, StableDiffusionSafetyChecker)
571573

572574
height, width, batch_size = 32, 64, 1
573-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
575+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
574576

575577
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED))
576578
diffusers_output = pipeline(**inputs, generator=get_generator("pt", SEED))

0 commit comments

Comments
 (0)