Skip to content

Commit 4be7cd2

Browse files
committed
add tests
1 parent 58ada07 commit 4be7cd2

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

tests/openvino/test_diffusion.py

+31-11
Original file line numberDiff line numberDiff line change
@@ -667,13 +667,14 @@ class OVPipelineForInpaintingTest(unittest.TestCase):
667667
if is_transformers_version(">=", "4.40.0"):
668668
SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
669669
SUPPORTED_ARCHITECTURES.append("flux")
670+
SUPPORTED_ARCHITECTURES.append("flux-fill")
670671

671672
AUTOMODEL_CLASS = AutoPipelineForInpainting
672673
OVMODEL_CLASS = OVPipelineForInpainting
673674

674675
TASK = "inpainting"
675676

676-
def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"):
677+
def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil", model_arch=""):
677678
inputs = _generate_prompts(batch_size=batch_size)
678679

679680
inputs["image"] = _generate_images(
@@ -683,7 +684,8 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
683684
height=height, width=width, batch_size=batch_size, channel=1, input_type=input_type
684685
)
685686

686-
inputs["strength"] = 0.75
687+
if model_arch != "flux-fill":
688+
inputs["strength"] = 0.75
687689
inputs["height"] = height
688690
inputs["width"] = width
689691

@@ -699,7 +701,12 @@ def test_load_vanilla_model_which_is_not_supported(self):
699701
@parameterized.expand(SUPPORTED_ARCHITECTURES)
700702
@require_diffusers
701703
def test_ov_pipeline_class_dispatch(self, model_arch: str):
702-
auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
704+
if model_arch != "flux-fill":
705+
auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
706+
else:
707+
from diffusers import FluxFillPipeline
708+
709+
auto_pipeline = FluxFillPipeline.from_pretrained(MODEL_NAMES[model_arch])
703710
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
704711

705712
self.assertEqual(ov_pipeline.auto_model_class, auto_pipeline.__class__)
@@ -713,7 +720,9 @@ def test_num_images_per_prompt(self, model_arch: str):
713720
for height in [64, 128]:
714721
for width in [64, 128]:
715722
for num_images_per_prompt in [1, 3]:
716-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
723+
inputs = self.generate_inputs(
724+
height=height, width=width, batch_size=batch_size, model_arch=model_arch
725+
)
717726
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
718727
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
719728

@@ -752,7 +761,9 @@ def test_shape(self, model_arch: str):
752761
height, width, batch_size = 128, 64, 1
753762

754763
for input_type in ["pil", "np", "pt"]:
755-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)
764+
inputs = self.generate_inputs(
765+
height=height, width=width, batch_size=batch_size, input_type=input_type, model_arch=model_arch
766+
)
756767

757768
for output_type in ["pil", "np", "pt", "latent"]:
758769
inputs["output_type"] = output_type
@@ -764,7 +775,7 @@ def test_shape(self, model_arch: str):
764775
elif output_type == "pt":
765776
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
766777
else:
767-
if model_arch != "flux":
778+
if not model_arch.startswith("flux"):
768779
out_channels = (
769780
pipeline.unet.config.out_channels
770781
if pipeline.unet is not None
@@ -782,17 +793,26 @@ def test_shape(self, model_arch: str):
782793
else:
783794
packed_height = height // pipeline.vae_scale_factor // 2
784795
packed_width = width // pipeline.vae_scale_factor // 2
785-
channels = pipeline.transformer.config.in_channels
796+
channels = (
797+
pipeline.transformer.config.in_channels
798+
if model_arch != "flux-fill"
799+
else pipeline.transformer.out_channels
800+
)
786801
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
787802

788803
@parameterized.expand(SUPPORTED_ARCHITECTURES)
789804
@require_diffusers
790805
def test_compare_to_diffusers_pipeline(self, model_arch: str):
791806
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
792-
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
807+
if model_arch != "flux-fill":
808+
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
809+
else:
810+
from diffusers import FluxFillPipeline
811+
812+
diffusers_pipeline = FluxFillPipeline.from_pretrained(MODEL_NAMES[model_arch])
793813

794814
height, width, batch_size = 64, 64, 1
795-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
815+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_arch=model_arch)
796816

797817
for output_type in ["latent", "np", "pt"]:
798818
inputs["output_type"] = output_type
@@ -804,7 +824,7 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
804824

805825
# test generation when input resolution nondevisible on 64
806826
height, width, batch_size = 96, 96, 1
807-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
827+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_arch=model_arch)
808828

809829
for output_type in ["latent", "np", "pt"]:
810830
inputs["output_type"] = output_type
@@ -820,7 +840,7 @@ def test_image_reproducibility(self, model_arch: str):
820840
pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
821841

822842
height, width, batch_size = 64, 64, 1
823-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
843+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_arch=model_arch)
824844

825845
for generator_framework in ["np", "pt"]:
826846
ov_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED))

tests/openvino/utils_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
"falcon-40b": "katuni4ka/tiny-random-falcon-40b",
6767
"flaubert": "hf-internal-testing/tiny-random-flaubert",
6868
"flux": "katuni4ka/tiny-random-flux",
69+
"flux-fill": "katuni4ka/tiny-random-flux-fill",
6970
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
7071
"gpt2": "hf-internal-testing/tiny-random-gpt2",
7172
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",

0 commit comments

Comments
 (0)