Skip to content

Commit b38f191

Browse files
authored
enable t5 in SD3 pipe testing and flux img2img and inpaint (#1016)
* enable t5 in SD3 pipe testing and flux img2img and inpaint
* use object as non supported class
* fix sd3 tests
1 parent ee96c82 commit b38f191

File tree

5 files changed

+140
-52
lines changed

5 files changed

+140
-52
lines changed

optimum/intel/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@
106106
"OVLatentConsistencyModelPipeline",
107107
"OVLatentConsistencyModelImg2ImgPipeline",
108108
"OVFluxPipeline",
109+
"OVFluxImg2ImgPipeline",
110+
"OVFluxInpaintPipeline",
109111
"OVPipelineForImage2Image",
110112
"OVPipelineForText2Image",
111113
"OVPipelineForInpainting",
@@ -126,6 +128,8 @@
126128
"OVLatentConsistencyModelPipeline",
127129
"OVLatentConsistencyModelImg2ImgPipeline",
128130
"OVFluxPipeline",
131+
"OVFluxImg2ImgPipeline",
132+
"OVFluxInpaintPipeline",
129133
"OVPipelineForImage2Image",
130134
"OVPipelineForText2Image",
131135
"OVPipelineForInpainting",

optimum/intel/openvino/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@
8282
if is_diffusers_available():
8383
from .modeling_diffusion import (
8484
OVDiffusionPipeline,
85+
OVFluxImg2ImgPipeline,
86+
OVFluxInpaintPipeline,
8587
OVFluxPipeline,
8688
OVLatentConsistencyModelImg2ImgPipeline,
8789
OVLatentConsistencyModelPipeline,

optimum/intel/openvino/modeling_diffusion.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,20 @@
8686
if is_diffusers_version(">=", "0.29.0"):
8787
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
8888
else:
89-
StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline = StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
89+
StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline = object, object
9090

9191
if is_diffusers_version(">=", "0.30.0"):
9292
from diffusers import FluxPipeline, StableDiffusion3InpaintPipeline
9393
else:
94-
StableDiffusion3InpaintPipeline = StableDiffusionInpaintPipeline
95-
FluxPipeline = StableDiffusionPipeline
94+
StableDiffusion3InpaintPipeline = object
95+
FluxPipeline = object
96+
97+
98+
if is_diffusers_version(">=", "0.31.0"):
99+
from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline
100+
else:
101+
FluxImg2ImgPipeline = object
102+
FluxInpaintPipeline = object
96103

97104

98105
DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER = "transformer"
@@ -887,9 +894,6 @@ def compile(self):
887894
def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs):
888895
return cls.load_config(config_name_or_path, **kwargs)
889896

890-
def _save_config(self, save_directory):
891-
self.save_config(save_directory)
892-
893897
@property
894898
def components(self) -> Dict[str, Any]:
895899
components = {
@@ -1447,6 +1451,18 @@ class OVFluxPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, FluxPip
14471451
auto_model_class = FluxPipeline
14481452

14491453

1454+
class OVFluxImg2ImgPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, FluxImg2ImgPipeline):
1455+
main_input_name = "prompt"
1456+
export_feature = "image-to-image"
1457+
auto_model_class = FluxImg2ImgPipeline
1458+
1459+
1460+
class OVFluxInpaintPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, FluxInpaintPipeline):
1461+
main_input_name = "prompt"
1462+
export_feature = "inpainting"
1463+
auto_model_class = FluxInpaintPipeline
1464+
1465+
14501466
SUPPORTED_OV_PIPELINES = [
14511467
OVStableDiffusionPipeline,
14521468
OVStableDiffusionImg2ImgPipeline,
@@ -1510,6 +1526,10 @@ def _get_ov_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tru
15101526
OV_INPAINT_PIPELINES_MAPPING["stable-diffusion-3"] = OVStableDiffusion3InpaintPipeline
15111527
OV_TEXT2IMAGE_PIPELINES_MAPPING["flux"] = OVFluxPipeline
15121528

1529+
if is_diffusers_version(">=", "0.31.0"):
1530+
SUPPORTED_OV_PIPELINES.extend([OVFluxImg2ImgPipeline, OVFluxInpaintPipeline])
1531+
OV_INPAINT_PIPELINES_MAPPING["flux"] = OVFluxInpaintPipeline
1532+
OV_IMAGE2IMAGE_PIPELINES_MAPPING["flux"] = OVFluxImg2ImgPipeline
15131533

15141534
SUPPORTED_OV_PIPELINES_MAPPINGS = [
15151535
OV_TEXT2IMAGE_PIPELINES_MAPPING,

optimum/intel/utils/dummy_openvino_and_diffusers_objects.py

+22
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,25 @@ def __init__(self, *args, **kwargs):
189189
@classmethod
190190
def from_pretrained(cls, *args, **kwargs):
191191
requires_backends(cls, ["openvino", "diffusers"])
192+
193+
194+
class OVFluxImg2ImgPipeline(metaclass=DummyObject):
195+
_backends = ["openvino", "diffusers"]
196+
197+
def __init__(self, *args, **kwargs):
198+
requires_backends(self, ["openvino", "diffusers"])
199+
200+
@classmethod
201+
def from_pretrained(cls, *args, **kwargs):
202+
requires_backends(cls, ["openvino", "diffusers"])
203+
204+
205+
class OVFluxInpaintPipeline(metaclass=DummyObject):
206+
_backends = ["openvino", "diffusers"]
207+
208+
def __init__(self, *args, **kwargs):
209+
requires_backends(self, ["openvino", "diffusers"])
210+
211+
@classmethod
212+
def from_pretrained(cls, *args, **kwargs):
213+
requires_backends(cls, ["openvino", "diffusers"])

tests/openvino/test_diffusion.py

+86-46
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
import unittest
1617
from pathlib import Path
1718

@@ -134,8 +135,8 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
134135
height, width, batch_size = 128, 128, 1
135136
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
136137

137-
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
138-
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
138+
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
139+
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
139140

140141
for output_type in ["latent", "np", "pt"]:
141142
inputs["output_type"] = output_type
@@ -330,6 +331,15 @@ def test_load_and_save_pipeline_with_safety_checker(self):
330331
]:
331332
subdir_path = Path(tmpdirname) / subdir
332333
self.assertTrue(subdir_path.is_dir())
334+
# check that config contains original model classes
335+
pipeline_config = Path(tmpdirname) / "model_index.json"
336+
self.assertTrue(pipeline_config.exists())
337+
with pipeline_config.open("r") as f:
338+
config = json.load(f)
339+
for key in ["unet", "vae", "text_encoder"]:
340+
model_lib, model_class = config[key]
341+
self.assertTrue(model_lib in ["diffusers", "transformers"])
342+
self.assertFalse(model_class.startswith("OV"))
333343
loaded_pipeline = self.OVMODEL_CLASS.from_pretrained(tmpdirname)
334344
self.assertTrue(loaded_pipeline.safety_checker is not None)
335345
self.assertIsInstance(loaded_pipeline.safety_checker, StableDiffusionSafetyChecker)
@@ -398,19 +408,24 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
398408
SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"]
399409
if is_transformers_version(">=", "4.40.0"):
400410
SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
411+
SUPPORTED_ARCHITECTURES.append("flux")
401412

402413
AUTOMODEL_CLASS = AutoPipelineForImage2Image
403414
OVMODEL_CLASS = OVPipelineForImage2Image
404415

405416
TASK = "image-to-image"
406417

407-
def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"):
418+
def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil", model_type=None):
408419
inputs = _generate_prompts(batch_size=batch_size)
409420

410421
inputs["image"] = _generate_images(
411422
height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
412423
)
413424

425+
if "flux" == model_type:
426+
inputs["height"] = height
427+
inputs["width"] = width
428+
414429
inputs["strength"] = 0.75
415430

416431
return inputs
@@ -439,15 +454,17 @@ def test_num_images_per_prompt(self, model_arch: str):
439454
for height in [64, 128]:
440455
for width in [64, 128]:
441456
for num_images_per_prompt in [1, 3]:
442-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
457+
inputs = self.generate_inputs(
458+
height=height, width=width, batch_size=batch_size, model_type=model_arch
459+
)
443460
outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images
444461
self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3))
445462

446463
@parameterized.expand(["stable-diffusion", "stable-diffusion-xl", "latent-consistency"])
447464
@require_diffusers
448465
def test_callback(self, model_arch: str):
449466
height, width, batch_size = 32, 64, 1
450-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
467+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
451468

452469
class Callback:
453470
def __init__(self):
@@ -478,7 +495,9 @@ def test_shape(self, model_arch: str):
478495
height, width, batch_size = 128, 64, 1
479496

480497
for input_type in ["pil", "np", "pt"]:
481-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)
498+
inputs = self.generate_inputs(
499+
height=height, width=width, batch_size=batch_size, input_type=input_type, model_type=model_arch
500+
)
482501

483502
for output_type in ["pil", "np", "pt", "latent"]:
484503
inputs["output_type"] = output_type
@@ -490,29 +509,35 @@ def test_shape(self, model_arch: str):
490509
elif output_type == "pt":
491510
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
492511
else:
493-
out_channels = (
494-
pipeline.unet.config.out_channels
495-
if pipeline.unet is not None
496-
else pipeline.transformer.config.out_channels
497-
)
498-
self.assertEqual(
499-
outputs.shape,
500-
(
501-
batch_size,
502-
out_channels,
503-
height // pipeline.vae_scale_factor,
504-
width // pipeline.vae_scale_factor,
505-
),
506-
)
512+
if model_arch != "flux":
513+
out_channels = (
514+
pipeline.unet.config.out_channels
515+
if pipeline.unet is not None
516+
else pipeline.transformer.config.out_channels
517+
)
518+
self.assertEqual(
519+
outputs.shape,
520+
(
521+
batch_size,
522+
out_channels,
523+
height // pipeline.vae_scale_factor,
524+
width // pipeline.vae_scale_factor,
525+
),
526+
)
527+
else:
528+
packed_height = height // pipeline.vae_scale_factor
529+
packed_width = width // pipeline.vae_scale_factor
530+
channels = pipeline.transformer.config.in_channels
531+
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
507532

508533
@parameterized.expand(SUPPORTED_ARCHITECTURES)
509534
@require_diffusers
510535
def test_compare_to_diffusers_pipeline(self, model_arch: str):
511536
height, width, batch_size = 128, 128, 1
512-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
537+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
513538

514-
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
515-
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None)
539+
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
540+
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
516541

517542
for output_type in ["latent", "np", "pt"]:
518543
print(output_type)
@@ -529,7 +554,7 @@ def test_image_reproducibility(self, model_arch: str):
529554
pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])
530555

531556
height, width, batch_size = 64, 64, 1
532-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
557+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
533558

534559
for generator_framework in ["np", "pt"]:
535560
ov_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED))
@@ -551,7 +576,7 @@ def test_safety_checker(self, model_arch: str):
551576
self.assertIsInstance(ov_pipeline.safety_checker, StableDiffusionSafetyChecker)
552577

553578
height, width, batch_size = 32, 64, 1
554-
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
579+
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch)
555580

556581
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED))
557582
diffusers_output = pipeline(**inputs, generator=get_generator("pt", SEED))
@@ -586,9 +611,13 @@ def test_height_width_properties(self, model_arch: str):
586611

587612
self.assertFalse(ov_pipeline.is_dynamic)
588613
expected_batch = batch_size * num_images_per_prompt
589-
if ov_pipeline.unet is None or "timestep_cond" not in {
590-
inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
591-
}:
614+
if (
615+
ov_pipeline.unet is not None
616+
and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}
617+
) or (
618+
ov_pipeline.transformer is not None
619+
and "txt_ids" not in {inputs.get_any_name() for inputs in ov_pipeline.transformer.model.inputs}
620+
):
592621
expected_batch *= 2
593622
self.assertEqual(ov_pipeline.batch_size, expected_batch)
594623
self.assertEqual(ov_pipeline.height, height)
@@ -604,7 +633,7 @@ def test_textual_inversion(self):
604633
model_id = "runwayml/stable-diffusion-v1-5"
605634
ti_id = "sd-concepts-library/cat-toy"
606635

607-
inputs = self.generate_inputs()
636+
inputs = self.generate_inputs(model_type="stable-diffusion")
608637
inputs["prompt"] = "A <cat-toy> backpack"
609638

610639
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
@@ -624,6 +653,7 @@ class OVPipelineForInpaintingTest(unittest.TestCase):
624653

625654
if is_transformers_version(">=", "4.40.0"):
626655
SUPPORTED_ARCHITECTURES.append("stable-diffusion-3")
656+
SUPPORTED_ARCHITECTURES.append("flux")
627657

628658
AUTOMODEL_CLASS = AutoPipelineForInpainting
629659
OVMODEL_CLASS = OVPipelineForInpainting
@@ -721,20 +751,26 @@ def test_shape(self, model_arch: str):
721751
elif output_type == "pt":
722752
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
723753
else:
724-
out_channels = (
725-
pipeline.unet.config.out_channels
726-
if pipeline.unet is not None
727-
else pipeline.transformer.config.out_channels
728-
)
729-
self.assertEqual(
730-
outputs.shape,
731-
(
732-
batch_size,
733-
out_channels,
734-
height // pipeline.vae_scale_factor,
735-
width // pipeline.vae_scale_factor,
736-
),
737-
)
754+
if model_arch != "flux":
755+
out_channels = (
756+
pipeline.unet.config.out_channels
757+
if pipeline.unet is not None
758+
else pipeline.transformer.config.out_channels
759+
)
760+
self.assertEqual(
761+
outputs.shape,
762+
(
763+
batch_size,
764+
out_channels,
765+
height // pipeline.vae_scale_factor,
766+
width // pipeline.vae_scale_factor,
767+
),
768+
)
769+
else:
770+
packed_height = height // pipeline.vae_scale_factor
771+
packed_width = width // pipeline.vae_scale_factor
772+
channels = pipeline.transformer.config.in_channels
773+
self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels))
738774

739775
@parameterized.expand(SUPPORTED_ARCHITECTURES)
740776
@require_diffusers
@@ -816,9 +852,13 @@ def test_height_width_properties(self, model_arch: str):
816852

817853
self.assertFalse(ov_pipeline.is_dynamic)
818854
expected_batch = batch_size * num_images_per_prompt
819-
if ov_pipeline.unet is None or "timestep_cond" not in {
820-
inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
821-
}:
855+
if (
856+
ov_pipeline.unet is not None
857+
and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}
858+
) or (
859+
ov_pipeline.transformer is not None
860+
and "txt_ids" not in {inputs.get_any_name() for inputs in ov_pipeline.transformer.model.inputs}
861+
):
822862
expected_batch *= 2
823863
self.assertEqual(
824864
ov_pipeline.batch_size,

0 commit comments

Comments
 (0)