Skip to content

Commit 6f1ef82

Browse files
committed
add export tests
1 parent 3cf8894 commit 6f1ef82

File tree

5 files changed

+18
-3
lines changed

5 files changed

+18
-3
lines changed

optimum/exporters/openvino/convert.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -919,7 +919,12 @@ def get_diffusion_models_for_export_ext(
919919

920920
if not is_sd3:
921921
return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
922+
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
922923

924+
return None, models_for_export
925+
926+
927+
def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
923928
models_for_export = {}
924929

925930
# Text encoder
@@ -948,7 +953,7 @@ def get_diffusion_models_for_export_ext(
948953
exporter=exporter,
949954
library_name="diffusers",
950955
task="semantic-segmentation",
951-
model_type="transformer",
956+
model_type="sd3-transformer",
952957
)
953958
transformer_export_config = export_config_constructor(
954959
pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
@@ -1015,4 +1020,4 @@ def get_diffusion_models_for_export_ext(
10151020
)
10161021
models_for_export["text_encoder_3"] = (text_encoder_3, export_config)
10171022

1018-
return None, models_for_export
1023+
return models_for_export

optimum/exporters/openvino/model_configs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1550,7 +1550,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
15501550
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
15511551

15521552

1553-
@register_in_tasks_manager("transformer", *["semantic-segmentation"], library_name="diffusers")
1553+
@register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers")
15541554
class TransformerOpenVINOConfig(UNetOnnxConfig):
15551555
DUMMY_INPUT_GENERATOR_CLASSES = UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + (
15561556
PooledProjectionsDummyInputGenerator,

optimum/intel/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@
100100
"OVStableDiffusionXLPipeline",
101101
"OVStableDiffusionXLImg2ImgPipeline",
102102
"OVStableDiffusionXLInpaintPipeline",
103+
"OVStableDiffusion3Pipeline",
104+
"OVStableDiffusion3Image2ImagePipeline",
105+
"OVStableDiffusion3InpaintPipeline",
103106
"OVLatentConsistencyModelPipeline",
104107
"OVLatentConsistencyModelImg2ImgPipeline",
105108
"OVPipelineForImage2Image",
@@ -116,6 +119,9 @@
116119
"OVStableDiffusionXLPipeline",
117120
"OVStableDiffusionXLImg2ImgPipeline",
118121
"OVStableDiffusionXLInpaintPipeline",
122+
"OVStableDiffusion3Pipeline",
123+
"OVStableDiffusion3Image2ImagePipeline",
124+
"OVStableDiffusion3InpaintPipeline",
119125
"OVLatentConsistencyModelPipeline",
120126
"OVLatentConsistencyModelImg2ImgPipeline",
121127
"OVPipelineForImage2Image",

tests/openvino/test_export.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
OVModelForSequenceClassification,
4242
OVModelForSpeechSeq2Seq,
4343
OVModelForTokenClassification,
44+
OVStableDiffusion3Pipeline,
4445
OVStableDiffusionPipeline,
4546
OVStableDiffusionXLImg2ImgPipeline,
4647
OVStableDiffusionXLPipeline,
@@ -68,6 +69,7 @@ class ExportModelTest(unittest.TestCase):
6869
"stable-diffusion-xl": OVStableDiffusionXLPipeline,
6970
"stable-diffusion-xl-refiner": OVStableDiffusionXLImg2ImgPipeline,
7071
"latent-consistency": OVLatentConsistencyModelPipeline,
72+
"stable-diffusion-3": OVStableDiffusion3Pipeline,
7173
}
7274

7375
GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper")

tests/openvino/test_exporters_cli.py

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class OVCLIExportTestCase(unittest.TestCase):
7171
("feature-extraction", "blenderbot"),
7272
("text-to-image", "stable-diffusion"),
7373
("text-to-image", "stable-diffusion-xl"),
74+
("text-to-image", "stable-diffusion-3"),
7475
("image-to-image", "stable-diffusion-xl-refiner"),
7576
)
7677
EXPECTED_NUMBER_OF_TOKENIZER_MODELS = {
@@ -85,6 +86,7 @@ class OVCLIExportTestCase(unittest.TestCase):
8586
"blenderbot": 2 if is_tokenizers_version("<", "0.20") else 0,
8687
"stable-diffusion": 2 if is_tokenizers_version("<", "0.20") else 0,
8788
"stable-diffusion-xl": 4 if is_tokenizers_version("<", "0.20") else 0,
89+
"stable-diffusion-3": 6 if is_tokenizers_version("<", "0.20") else 0,
8890
}
8991

9092
SUPPORTED_SD_HYBRID_ARCHITECTURES = (

0 commit comments

Comments
 (0)