@@ -667,13 +667,14 @@ class OVPipelineForInpaintingTest(unittest.TestCase):
667
667
if is_transformers_version (">=" , "4.40.0" ):
668
668
SUPPORTED_ARCHITECTURES .append ("stable-diffusion-3" )
669
669
SUPPORTED_ARCHITECTURES .append ("flux" )
670
+ SUPPORTED_ARCHITECTURES .append ("flux-fill" )
670
671
671
672
AUTOMODEL_CLASS = AutoPipelineForInpainting
672
673
OVMODEL_CLASS = OVPipelineForInpainting
673
674
674
675
TASK = "inpainting"
675
676
676
- def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" ):
677
+ def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" , model_arch = "" ):
677
678
inputs = _generate_prompts (batch_size = batch_size )
678
679
679
680
inputs ["image" ] = _generate_images (
@@ -683,7 +684,8 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
683
684
height = height , width = width , batch_size = batch_size , channel = 1 , input_type = input_type
684
685
)
685
686
686
- inputs ["strength" ] = 0.75
687
+ if model_arch != "flux-fill" :
688
+ inputs ["strength" ] = 0.75
687
689
inputs ["height" ] = height
688
690
inputs ["width" ] = width
689
691
@@ -699,7 +701,12 @@ def test_load_vanilla_model_which_is_not_supported(self):
699
701
@parameterized .expand (SUPPORTED_ARCHITECTURES )
700
702
@require_diffusers
701
703
def test_ov_pipeline_class_dispatch (self , model_arch : str ):
702
- auto_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
704
+ if model_arch != "flux-fill" :
705
+ auto_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
706
+ else :
707
+ from diffusers import FluxFillPipeline
708
+
709
+ auto_pipeline = FluxFillPipeline .from_pretrained (MODEL_NAMES [model_arch ])
703
710
ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
704
711
705
712
self .assertEqual (ov_pipeline .auto_model_class , auto_pipeline .__class__ )
@@ -713,7 +720,9 @@ def test_num_images_per_prompt(self, model_arch: str):
713
720
for height in [64 , 128 ]:
714
721
for width in [64 , 128 ]:
715
722
for num_images_per_prompt in [1 , 3 ]:
716
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
723
+ inputs = self .generate_inputs (
724
+ height = height , width = width , batch_size = batch_size , model_arch = model_arch
725
+ )
717
726
outputs = pipeline (** inputs , num_images_per_prompt = num_images_per_prompt ).images
718
727
self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , height , width , 3 ))
719
728
@@ -752,7 +761,9 @@ def test_shape(self, model_arch: str):
752
761
height , width , batch_size = 128 , 64 , 1
753
762
754
763
for input_type in ["pil" , "np" , "pt" ]:
755
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , input_type = input_type )
764
+ inputs = self .generate_inputs (
765
+ height = height , width = width , batch_size = batch_size , input_type = input_type , model_arch = model_arch
766
+ )
756
767
757
768
for output_type in ["pil" , "np" , "pt" , "latent" ]:
758
769
inputs ["output_type" ] = output_type
@@ -764,7 +775,7 @@ def test_shape(self, model_arch: str):
764
775
elif output_type == "pt" :
765
776
self .assertEqual (outputs .shape , (batch_size , 3 , height , width ))
766
777
else :
767
- if model_arch != "flux" :
778
+ if not model_arch . startswith ( "flux" ) :
768
779
out_channels = (
769
780
pipeline .unet .config .out_channels
770
781
if pipeline .unet is not None
@@ -782,17 +793,26 @@ def test_shape(self, model_arch: str):
782
793
else :
783
794
packed_height = height // pipeline .vae_scale_factor // 2
784
795
packed_width = width // pipeline .vae_scale_factor // 2
785
- channels = pipeline .transformer .config .in_channels
796
+ channels = (
797
+ pipeline .transformer .config .in_channels
798
+ if model_arch != "flux-fill"
799
+ else pipeline .transformer .out_channels
800
+ )
786
801
self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
787
802
788
803
@parameterized .expand (SUPPORTED_ARCHITECTURES )
789
804
@require_diffusers
790
805
def test_compare_to_diffusers_pipeline (self , model_arch : str ):
791
806
ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
792
- diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
807
+ if model_arch != "flux-fill" :
808
+ diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
809
+ else :
810
+ from diffusers import FluxFillPipeline
811
+
812
+ diffusers_pipeline = FluxFillPipeline .from_pretrained (MODEL_NAMES [model_arch ])
793
813
794
814
height , width , batch_size = 64 , 64 , 1
795
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
815
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_arch = model_arch )
796
816
797
817
for output_type in ["latent" , "np" , "pt" ]:
798
818
inputs ["output_type" ] = output_type
@@ -804,7 +824,7 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
804
824
805
825
# test generation when input resolution nondevisible on 64
806
826
height , width , batch_size = 96 , 96 , 1
807
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
827
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_arch = model_arch )
808
828
809
829
for output_type in ["latent" , "np" , "pt" ]:
810
830
inputs ["output_type" ] = output_type
@@ -820,7 +840,7 @@ def test_image_reproducibility(self, model_arch: str):
820
840
pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
821
841
822
842
height , width , batch_size = 64 , 64 , 1
823
- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
843
+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_arch = model_arch )
824
844
825
845
for generator_framework in ["np" , "pt" ]:
826
846
ov_outputs_1 = pipeline (** inputs , generator = get_generator (generator_framework , SEED ))
0 commit comments