@@ -415,14 +415,16 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
415
415
416
416
TASK = "image-to-image"
417
417
418
- def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" ):
418
+ def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" , model_type = None ):
419
419
inputs = _generate_prompts (batch_size = batch_size )
420
420
421
421
inputs ["image" ] = _generate_images (
422
422
height = height , width = width , batch_size = batch_size , channel = channel , input_type = input_type
423
423
)
424
- inputs ["height" ] = height
425
- inputs ["width" ] = width
424
+
425
+ if "flux" in model_type :
426
+ inputs ["height" ] = height
427
+ inputs ["width" ] = width
426
428
427
429
inputs ["strength" ] = 0.75
428
430
@@ -452,15 +454,17 @@ def test_num_images_per_prompt(self, model_arch: str):
452
454
for height in [64 , 128 ]:
453
455
for width in [64 , 128 ]:
454
456
for num_images_per_prompt in [1 , 3 ]:
455
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
457
+ inputs = self .generate_inputs (
458
+ height = height , width = width , batch_size = batch_size , model_type = model_arch
459
+ )
456
460
outputs = pipeline (** inputs , num_images_per_prompt = num_images_per_prompt ).images
457
461
self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , height , width , 3 ))
458
462
459
463
@parameterized .expand (["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ])
460
464
@require_diffusers
461
465
def test_callback (self , model_arch : str ):
462
466
height , width , batch_size = 32 , 64 , 1
463
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
467
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
464
468
465
469
class Callback :
466
470
def __init__ (self ):
@@ -491,7 +495,9 @@ def test_shape(self, model_arch: str):
491
495
height , width , batch_size = 128 , 64 , 1
492
496
493
497
for input_type in ["pil" , "np" , "pt" ]:
494
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , input_type = input_type )
498
+ inputs = self .generate_inputs (
499
+ height = height , width = width , batch_size = batch_size , input_type = input_type , model_type = model_arch
500
+ )
495
501
496
502
for output_type in ["pil" , "np" , "pt" , "latent" ]:
497
503
inputs ["output_type" ] = output_type
@@ -528,7 +534,7 @@ def test_shape(self, model_arch: str):
528
534
@require_diffusers
529
535
def test_compare_to_diffusers_pipeline (self , model_arch : str ):
530
536
height , width , batch_size = 128 , 128 , 1
531
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
537
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
532
538
533
539
diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
534
540
ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
@@ -570,7 +576,7 @@ def test_safety_checker(self, model_arch: str):
570
576
self .assertIsInstance (ov_pipeline .safety_checker , StableDiffusionSafetyChecker )
571
577
572
578
height , width , batch_size = 32 , 64 , 1
573
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
579
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
574
580
575
581
ov_output = ov_pipeline (** inputs , generator = get_generator ("pt" , SEED ))
576
582
diffusers_output = pipeline (** inputs , generator = get_generator ("pt" , SEED ))
0 commit comments