# See the License for the specific language governing permissions and
# limitations under the License.

+ import json
import unittest
from pathlib import Path
@@ -134,8 +135,8 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
        height, width, batch_size = 128, 128, 1
        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)

-        ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
-        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
+        ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
+        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])

        for output_type in ["latent", "np", "pt"]:
            inputs["output_type"] = output_type
@@ -330,6 +331,15 @@ def test_load_and_save_pipeline_with_safety_checker(self):
            ]:
                subdir_path = Path(tmpdirname) / subdir
                self.assertTrue(subdir_path.is_dir())
+            # check that config contains original model classes
+            pipeline_config = Path(tmpdirname) / "model_index.json"
+            self.assertTrue(pipeline_config.exists())
+            with pipeline_config.open("r") as f:
+                config = json.load(f)
+                for key in ["unet", "vae", "text_encoder"]:
+                    model_lib, model_class = config[key]
+                    self.assertTrue(model_lib in ["diffusers", "transformers"])
+                    self.assertFalse(model_class.startswith("OV"))
            loaded_pipeline = self.OVMODEL_CLASS.from_pretrained(tmpdirname)
            self.assertTrue(loaded_pipeline.safety_checker is not None)
            self.assertIsInstance(loaded_pipeline.safety_checker, StableDiffusionSafetyChecker)
@@ -398,19 +408,24 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
    SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]
    if is_transformers_version(">=", "4.40.0"):
        SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
+        SUPPORTED_ARCHITECTURES.append("flux")

    AUTOMODEL_CLASS = AutoPipelineForImage2Image
    OVMODEL_CLASS = OVPipelineForImage2Image

    TASK = "image-to-image"

-    def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"):
+    def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil", model_type=None):
        inputs = _generate_prompts(batch_size=batch_size)

        inputs["image"] = _generate_images(
            height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
        )

+        if "flux" == model_type:
+            inputs["height"] = height
+            inputs["width"] = width
+
        inputs["strength"] = 0.75

        return inputs
@@ -439,15 +454,17 @@ def test_num_images_per_prompt(self, model_arch: str):
        for height in [64, 128]:
            for width in [64, 128]:
                for num_images_per_prompt in [1, 3]:
-                    inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+                    inputs = self.generate_inputs(
+                        height=height, width=width, batch_size=batch_size, model_type=model_arch
+                    )
                    outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
                    self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))

    @parameterized.expand(["stable-diffusion", "stable-diffusion-xl", "latent-consistency"])
    @require_diffusers
    def test_callback(self, model_arch: str):
        height, width, batch_size = 32, 64, 1
-        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)

        class Callback:
            def __init__(self):
@@ -478,7 +495,9 @@ def test_shape(self, model_arch: str):
        height, width, batch_size = 128, 64, 1

        for input_type in ["pil", "np", "pt"]:
-            inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)
+            inputs = self.generate_inputs(
+                height=height, width=width, batch_size=batch_size, input_type=input_type, model_type=model_arch
+            )

            for output_type in ["pil", "np", "pt", "latent"]:
                inputs["output_type"] = output_type
@@ -490,29 +509,35 @@ def test_shape(self, model_arch: str):
                elif output_type == "pt":
                    self.assertEqual(outputs.shape, (batch_size, 3, height, width))
                else:
-                    out_channels = (
-                        pipeline.unet.config.out_channels
-                        if pipeline.unet is not None
-                        else pipeline.transformer.config.out_channels
-                    )
-                    self.assertEqual(
-                        outputs.shape,
-                        (
-                            batch_size,
-                            out_channels,
-                            height // pipeline.vae_scale_factor,
-                            width // pipeline.vae_scale_factor,
-                        ),
-                    )
+                    if model_arch != "flux":
+                        out_channels = (
+                            pipeline.unet.config.out_channels
+                            if pipeline.unet is not None
+                            else pipeline.transformer.config.out_channels
+                        )
+                        self.assertEqual(
+                            outputs.shape,
+                            (
+                                batch_size,
+                                out_channels,
+                                height // pipeline.vae_scale_factor,
+                                width // pipeline.vae_scale_factor,
+                            ),
+                        )
+                    else:
+                        packed_height = height // pipeline.vae_scale_factor
+                        packed_width = width // pipeline.vae_scale_factor
+                        channels = pipeline.transformer.config.in_channels
+                        self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @require_diffusers
    def test_compare_to_diffusers_pipeline(self, model_arch: str):
        height, width, batch_size = 128, 128, 1
-        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)

-        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
-        ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
+        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
+        ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])

        for output_type in ["latent", "np", "pt"]:
            print(output_type)
@@ -529,7 +554,7 @@ def test_image_reproducibility(self, model_arch: str):
        pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])

        height, width, batch_size = 64, 64, 1
-        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)

        for generator_framework in ["np", "pt"]:
            ov_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED))
@@ -551,7 +576,7 @@ def test_safety_checker(self, model_arch: str):
        self.assertIsInstance(ov_pipeline.safety_checker, StableDiffusionSafetyChecker)

        height, width, batch_size = 32, 64, 1
-        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
+        inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)

        ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED))
        diffusers_output = pipeline(**inputs, generator=get_generator("pt", SEED))
@@ -586,9 +611,13 @@ def test_height_width_properties(self, model_arch: str):

        self.assertFalse(ov_pipeline.is_dynamic)
        expected_batch = batch_size * num_images_per_prompt
-        if ov_pipeline.unet is None or "timestep_cond" not in {
-            inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
-        }:
+        if (
+            ov_pipeline.unet is not None
+            and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}
+        ) or (
+            ov_pipeline.transformer is not None
+            and "txt_ids" not in {inputs.get_any_name() for inputs in ov_pipeline.transformer.model.inputs}
+        ):
            expected_batch *= 2
        self.assertEqual(ov_pipeline.batch_size, expected_batch)
        self.assertEqual(ov_pipeline.height, height)
@@ -604,7 +633,7 @@ def test_textual_inversion(self):
        model_id = "runwayml/stable-diffusion-v1-5"
        ti_id = "sd-concepts-library/cat-toy"

-        inputs = self.generate_inputs()
+        inputs = self.generate_inputs(model_type="stable-diffusion")
        inputs["prompt"] = "A <cat-toy> backpack"

        diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
@@ -624,6 +653,7 @@ class OVPipelineForInpaintingTest(unittest.TestCase):

    if is_transformers_version(">=", "4.40.0"):
        SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
+        SUPPORTED_ARCHITECTURES.append("flux")

    AUTOMODEL_CLASS = AutoPipelineForInpainting
    OVMODEL_CLASS = OVPipelineForInpainting
@@ -721,20 +751,26 @@ def test_shape(self, model_arch: str):
                elif output_type == "pt":
                    self.assertEqual(outputs.shape, (batch_size, 3, height, width))
                else:
-                    out_channels = (
-                        pipeline.unet.config.out_channels
-                        if pipeline.unet is not None
-                        else pipeline.transformer.config.out_channels
-                    )
-                    self.assertEqual(
-                        outputs.shape,
-                        (
-                            batch_size,
-                            out_channels,
-                            height // pipeline.vae_scale_factor,
-                            width // pipeline.vae_scale_factor,
-                        ),
-                    )
+                    if model_arch != "flux":
+                        out_channels = (
+                            pipeline.unet.config.out_channels
+                            if pipeline.unet is not None
+                            else pipeline.transformer.config.out_channels
+                        )
+                        self.assertEqual(
+                            outputs.shape,
+                            (
+                                batch_size,
+                                out_channels,
+                                height // pipeline.vae_scale_factor,
+                                width // pipeline.vae_scale_factor,
+                            ),
+                        )
+                    else:
+                        packed_height = height // pipeline.vae_scale_factor
+                        packed_width = width // pipeline.vae_scale_factor
+                        channels = pipeline.transformer.config.in_channels
+                        self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @require_diffusers
@@ -816,9 +852,13 @@ def test_height_width_properties(self, model_arch: str):

        self.assertFalse(ov_pipeline.is_dynamic)
        expected_batch = batch_size * num_images_per_prompt
-        if ov_pipeline.unet is None or "timestep_cond" not in {
-            inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
-        }:
+        if (
+            ov_pipeline.unet is not None
+            and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}
+        ) or (
+            ov_pipeline.transformer is not None
+            and "txt_ids" not in {inputs.get_any_name() for inputs in ov_pipeline.transformer.model.inputs}
+        ):
            expected_batch *= 2
        self.assertEqual(
            ov_pipeline.batch_size,