12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import json
15
16
import unittest
16
17
from pathlib import Path
17
18
@@ -134,8 +135,8 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
134
135
height , width , batch_size = 128 , 128 , 1
135
136
inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
136
137
137
- ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ], text_encoder_3 = None )
138
- diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ], text_encoder_3 = None )
138
+ ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
139
+ diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
139
140
140
141
for output_type in ["latent" , "np" , "pt" ]:
141
142
inputs ["output_type" ] = output_type
@@ -330,6 +331,15 @@ def test_load_and_save_pipeline_with_safety_checker(self):
330
331
]:
331
332
subdir_path = Path (tmpdirname ) / subdir
332
333
self .assertTrue (subdir_path .is_dir ())
334
+ # check that config contains original model classes
335
+ pipeline_config = Path (tmpdirname ) / "model_index.json"
336
+ self .assertTrue (pipeline_config .exists ())
337
+ with pipeline_config .open ("r" ) as f :
338
+ config = json .load (f )
339
+ for key in ["unet" , "vae" , "text_encoder" ]:
340
+ model_lib , model_class = config [key ]
341
+ self .assertTrue (model_lib in ["diffusers" , "transformers" ])
342
+ self .assertFalse (model_class .startswith ("OV" ))
333
343
loaded_pipeline = self .OVMODEL_CLASS .from_pretrained (tmpdirname )
334
344
self .assertTrue (loaded_pipeline .safety_checker is not None )
335
345
self .assertIsInstance (loaded_pipeline .safety_checker , StableDiffusionSafetyChecker )
@@ -398,6 +408,7 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
398
408
SUPPORTED_ARCHITECTURES = ["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ]
399
409
if is_transformers_version (">=" , "4.40.0" ):
400
410
SUPPORTED_ARCHITECTURES .append ("stable-diffusion-3" )
411
+ SUPPORTED_ARCHITECTURES .append ("flux" )
401
412
402
413
AUTOMODEL_CLASS = AutoPipelineForImage2Image
403
414
OVMODEL_CLASS = OVPipelineForImage2Image
@@ -410,6 +421,8 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
410
421
inputs ["image" ] = _generate_images (
411
422
height = height , width = width , batch_size = batch_size , channel = channel , input_type = input_type
412
423
)
424
+ inputs ["height" ] = height
425
+ inputs ["width" ] = width
413
426
414
427
inputs ["strength" ] = 0.75
415
428
@@ -490,29 +503,35 @@ def test_shape(self, model_arch: str):
490
503
elif output_type == "pt" :
491
504
self .assertEqual (outputs .shape , (batch_size , 3 , height , width ))
492
505
else :
493
- out_channels = (
494
- pipeline .unet .config .out_channels
495
- if pipeline .unet is not None
496
- else pipeline .transformer .config .out_channels
497
- )
498
- self .assertEqual (
499
- outputs .shape ,
500
- (
501
- batch_size ,
502
- out_channels ,
503
- height // pipeline .vae_scale_factor ,
504
- width // pipeline .vae_scale_factor ,
505
- ),
506
- )
506
+ if model_arch != "flux" :
507
+ out_channels = (
508
+ pipeline .unet .config .out_channels
509
+ if pipeline .unet is not None
510
+ else pipeline .transformer .config .out_channels
511
+ )
512
+ self .assertEqual (
513
+ outputs .shape ,
514
+ (
515
+ batch_size ,
516
+ out_channels ,
517
+ height // pipeline .vae_scale_factor ,
518
+ width // pipeline .vae_scale_factor ,
519
+ ),
520
+ )
521
+ else :
522
+ packed_height = height // pipeline .vae_scale_factor
523
+ packed_width = width // pipeline .vae_scale_factor
524
+ channels = pipeline .transformer .config .in_channels
525
+ self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
507
526
508
527
@parameterized .expand (SUPPORTED_ARCHITECTURES )
509
528
@require_diffusers
510
529
def test_compare_to_diffusers_pipeline (self , model_arch : str ):
511
530
height , width , batch_size = 128 , 128 , 1
512
531
inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
513
532
514
- diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ], text_encoder_3 = None )
515
- ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ], text_encoder_3 = None )
533
+ diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
534
+ ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
516
535
517
536
for output_type in ["latent" , "np" , "pt" ]:
518
537
print (output_type )
@@ -586,9 +605,13 @@ def test_height_width_properties(self, model_arch: str):
586
605
587
606
self .assertFalse (ov_pipeline .is_dynamic )
588
607
expected_batch = batch_size * num_images_per_prompt
589
- if ov_pipeline .unet is None or "timestep_cond" not in {
590
- inputs .get_any_name () for inputs in ov_pipeline .unet .model .inputs
591
- }:
608
+ if (
609
+ ov_pipeline .unet is not None
610
+ and "timestep_cond" not in {inputs .get_any_name () for inputs in ov_pipeline .unet .model .inputs }
611
+ ) or (
612
+ ov_pipeline .transformer is not None
613
+ and "txt_ids" not in {inputs .get_any_name () for inputs in ov_pipeline .transformer .model .inputs }
614
+ ):
592
615
expected_batch *= 2
593
616
self .assertEqual (ov_pipeline .batch_size , expected_batch )
594
617
self .assertEqual (ov_pipeline .height , height )
@@ -624,6 +647,7 @@ class OVPipelineForInpaintingTest(unittest.TestCase):
624
647
625
648
if is_transformers_version (">=" , "4.40.0" ):
626
649
SUPPORTED_ARCHITECTURES .append ("stable-diffusion-3" )
650
+ SUPPORTED_ARCHITECTURES .append ("flux" )
627
651
628
652
AUTOMODEL_CLASS = AutoPipelineForInpainting
629
653
OVMODEL_CLASS = OVPipelineForInpainting
@@ -721,20 +745,26 @@ def test_shape(self, model_arch: str):
721
745
elif output_type == "pt" :
722
746
self .assertEqual (outputs .shape , (batch_size , 3 , height , width ))
723
747
else :
724
- out_channels = (
725
- pipeline .unet .config .out_channels
726
- if pipeline .unet is not None
727
- else pipeline .transformer .config .out_channels
728
- )
729
- self .assertEqual (
730
- outputs .shape ,
731
- (
732
- batch_size ,
733
- out_channels ,
734
- height // pipeline .vae_scale_factor ,
735
- width // pipeline .vae_scale_factor ,
736
- ),
737
- )
748
+ if model_arch != "flux" :
749
+ out_channels = (
750
+ pipeline .unet .config .out_channels
751
+ if pipeline .unet is not None
752
+ else pipeline .transformer .config .out_channels
753
+ )
754
+ self .assertEqual (
755
+ outputs .shape ,
756
+ (
757
+ batch_size ,
758
+ out_channels ,
759
+ height // pipeline .vae_scale_factor ,
760
+ width // pipeline .vae_scale_factor ,
761
+ ),
762
+ )
763
+ else :
764
+ packed_height = height // pipeline .vae_scale_factor
765
+ packed_width = width // pipeline .vae_scale_factor
766
+ channels = pipeline .transformer .config .in_channels
767
+ self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
738
768
739
769
@parameterized .expand (SUPPORTED_ARCHITECTURES )
740
770
@require_diffusers
@@ -816,9 +846,13 @@ def test_height_width_properties(self, model_arch: str):
816
846
817
847
self .assertFalse (ov_pipeline .is_dynamic )
818
848
expected_batch = batch_size * num_images_per_prompt
819
- if ov_pipeline .unet is None or "timestep_cond" not in {
820
- inputs .get_any_name () for inputs in ov_pipeline .unet .model .inputs
821
- }:
849
+ if (
850
+ ov_pipeline .unet is not None
851
+ and "timestep_cond" not in {inputs .get_any_name () for inputs in ov_pipeline .unet .model .inputs }
852
+ ) or (
853
+ ov_pipeline .transformer is not None
854
+ and "txt_ids" not in {inputs .get_any_name () for inputs in ov_pipeline .transformer .model .inputs }
855
+ ):
822
856
expected_batch *= 2
823
857
self .assertEqual (
824
858
ov_pipeline .batch_size ,
0 commit comments