@@ -415,14 +415,16 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
415
415
416
416
TASK = "image-to-image"
417
417
418
- def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" ):
418
+ def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" , model_type = None ):
419
419
inputs = _generate_prompts (batch_size = batch_size )
420
420
421
421
inputs ["image" ] = _generate_images (
422
422
height = height , width = width , batch_size = batch_size , channel = channel , input_type = input_type
423
423
)
424
- inputs ["height" ] = height
425
- inputs ["width" ] = width
424
+
425
+ if "flux" in model_type :
426
+ inputs ["height" ] = height
427
+ inputs ["width" ] = width
426
428
427
429
inputs ["strength" ] = 0.75
428
430
@@ -452,15 +454,15 @@ def test_num_images_per_prompt(self, model_arch: str):
452
454
for height in [64 , 128 ]:
453
455
for width in [64 , 128 ]:
454
456
for num_images_per_prompt in [1 , 3 ]:
455
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
457
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
456
458
outputs = pipeline (** inputs , num_images_per_prompt = num_images_per_prompt ).images
457
459
self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , height , width , 3 ))
458
460
459
461
@parameterized .expand (["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ])
460
462
@require_diffusers
461
463
def test_callback (self , model_arch : str ):
462
464
height , width , batch_size = 32 , 64 , 1
463
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
465
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
464
466
465
467
class Callback :
466
468
def __init__ (self ):
@@ -491,7 +493,7 @@ def test_shape(self, model_arch: str):
491
493
height , width , batch_size = 128 , 64 , 1
492
494
493
495
for input_type in ["pil" , "np" , "pt" ]:
494
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , input_type = input_type )
496
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , input_type = input_type , model_type = model_arch )
495
497
496
498
for output_type in ["pil" , "np" , "pt" , "latent" ]:
497
499
inputs ["output_type" ] = output_type
@@ -528,7 +530,7 @@ def test_shape(self, model_arch: str):
528
530
@require_diffusers
529
531
def test_compare_to_diffusers_pipeline (self , model_arch : str ):
530
532
height , width , batch_size = 128 , 128 , 1
531
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
533
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
532
534
533
535
diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
534
536
ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
@@ -570,7 +572,7 @@ def test_safety_checker(self, model_arch: str):
570
572
self .assertIsInstance (ov_pipeline .safety_checker , StableDiffusionSafetyChecker )
571
573
572
574
height , width , batch_size = 32 , 64 , 1
573
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
575
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
574
576
575
577
ov_output = ov_pipeline (** inputs , generator = get_generator ("pt" , SEED ))
576
578
diffusers_output = pipeline (** inputs , generator = get_generator ("pt" , SEED ))
0 commit comments