Skip to content

Commit a70255d

Browse files
authored
Fix compile_only mode for diffusers with transformer as main model (#1101)
* add test
* config saving from model
1 parent 7d7de7c commit a70255d

File tree

2 files changed

+81
-4
lines changed

2 files changed

+81
-4
lines changed

optimum/intel/openvino/modeling_diffusion.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,11 @@ def __init__(
162162
"Please provide `compile=True` if you want to use `compile_only=True` or set `compile_only=False`"
163163
)
164164

165-
if not isinstance(unet, openvino.runtime.CompiledModel):
165+
main_model = unet if unet is not None else transformer
166+
if not isinstance(main_model, openvino.runtime.CompiledModel):
166167
raise ValueError("`compile_only` expect that already compiled model will be provided")
167168

168-
model_is_dynamic = model_has_dynamic_inputs(unet)
169+
model_is_dynamic = model_has_dynamic_inputs(main_model)
169170
if dynamic_shapes ^ model_is_dynamic:
170171
requested_shapes = "dynamic" if dynamic_shapes else "static"
171172
compiled_shapes = "dynamic" if model_is_dynamic else "static"
@@ -291,6 +292,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
291292
if config_path.is_file():
292293
config_save_path = save_path / CONFIG_NAME
293294
shutil.copyfile(config_path, config_save_path)
295+
else:
296+
if hasattr(model, "save_config"):
297+
model.save_config(save_path)
298+
elif hasattr(model, "config") and hasattr(model.config, "save_pretrained"):
299+
model.config.save_pretrained(save_path)
294300

295301
self.scheduler.save_pretrained(save_directory / "scheduler")
296302

tests/openvino/test_modeling.py

+73-2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767

6868
from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
6969
from optimum.intel import (
70+
OVDiffusionPipeline,
71+
OVFluxPipeline,
7072
OVModelForAudioClassification,
7173
OVModelForAudioFrameClassification,
7274
OVModelForAudioXVector,
@@ -107,7 +109,9 @@
107109
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
108110
from optimum.intel.utils.modeling_utils import _find_files_matching_pattern
109111
from optimum.utils import (
112+
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
110113
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
114+
DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER,
111115
DIFFUSION_MODEL_UNET_SUBFOLDER,
112116
DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
113117
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
@@ -140,7 +144,8 @@ def __init__(self, *args, **kwargs):
140144
self.OV_MODEL_ID = "echarlaix/distilbert-base-uncased-finetuned-sst-2-english-openvino"
141145
self.OV_DECODER_MODEL_ID = "helenai/gpt2-ov"
142146
self.OV_SEQ2SEQ_MODEL_ID = "echarlaix/t5-small-openvino"
143-
self.OV_DIFFUSION_MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-openvino"
147+
self.OV_SD_DIFFUSION_MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-openvino"
148+
self.OV_FLUX_DIFFUSION_MODEL_ID = "katuni4ka/tiny-random-flux-ov"
144149
self.OV_VLM_MODEL_ID = "katuni4ka/tiny-random-llava-ov"
145150

146151
def test_load_from_hub_and_save_model(self):
@@ -337,7 +342,7 @@ def test_load_from_hub_and_save_seq2seq_model(self):
337342

338343
@require_diffusers
339344
def test_load_from_hub_and_save_stable_diffusion_model(self):
340-
loaded_pipeline = OVStableDiffusionPipeline.from_pretrained(self.OV_DIFFUSION_MODEL_ID, compile=False)
345+
loaded_pipeline = OVStableDiffusionPipeline.from_pretrained(self.OV_SD_DIFFUSION_MODEL_ID, compile=False)
341346
self.assertIsInstance(loaded_pipeline.config, Dict)
342347
# Test that PERFORMANCE_HINT is set to LATENCY by default
343348
self.assertEqual(loaded_pipeline.ov_config.get("PERFORMANCE_HINT"), "LATENCY")
@@ -391,6 +396,72 @@ def test_load_from_hub_and_save_stable_diffusion_model(self):
391396
del pipeline
392397
gc.collect()
393398

399+
@require_diffusers
@unittest.skipIf(
    is_transformers_version("<", "4.45"),
    "model tokenizer exported with tokenizers 0.20 is not compatible with old transformers",
)
def test_load_from_hub_and_save_flux_model(self):
    """Round-trip a Flux pipeline through load, run, save and reload.

    The pipeline is loaded uncompiled via ``OVDiffusionPipeline``, compiled,
    run once to get reference images, then saved to a temporary directory.
    It is reloaded from there twice (regular and ``compile_only=True``) and
    both reloaded pipelines must reproduce the reference images.
    """
    loaded_pipeline = OVDiffusionPipeline.from_pretrained(self.OV_FLUX_DIFFUSION_MODEL_ID, compile=False)
    self.assertIsInstance(loaded_pipeline, OVFluxPipeline)
    self.assertIsInstance(loaded_pipeline.config, Dict)
    # Test that PERFORMANCE_HINT is set to LATENCY by default
    self.assertEqual(loaded_pipeline.ov_config.get("PERFORMANCE_HINT"), "LATENCY")
    loaded_pipeline.compile()
    # This pipeline is expected to expose a transformer and no unet.
    self.assertIsNone(loaded_pipeline.unet)
    self.assertEqual(loaded_pipeline.transformer.request.get_property("PERFORMANCE_HINT"), "LATENCY")

    batch_size, height, width = 2, 16, 16
    inputs = {
        "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size,
        "height": height,
        "width": width,
        "num_inference_steps": 2,
        "output_type": "np",
    }

    # Seed both RNGs before every run so the outputs are comparable.
    np.random.seed(0)
    torch.manual_seed(0)
    pipeline_outputs = loaded_pipeline(**inputs).images
    self.assertEqual(pipeline_outputs.shape, (batch_size, height, width, 3))

    # A tuple keeps the iteration order (and thus failure reports) deterministic.
    expected_subfolders = (
        DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER,
        DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
        DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
        DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
        DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
    )
    ov_bin_file_name = OV_XML_FILE_NAME.replace(".xml", ".bin")

    with TemporaryDirectory() as tmpdirname:
        loaded_pipeline.save_pretrained(tmpdirname)
        pipeline = OVDiffusionPipeline.from_pretrained(tmpdirname)
        # Check the *reloaded* object: the generic entry point must resolve
        # the saved directory back to the Flux pipeline class.
        self.assertIsInstance(pipeline, OVFluxPipeline)

        folder_contents = os.listdir(tmpdirname)
        self.assertIn(loaded_pipeline.config_name, folder_contents)
        for subfolder in expected_subfolders:
            subfolder_contents = os.listdir(os.path.join(tmpdirname, subfolder))
            self.assertIn(OV_XML_FILE_NAME, subfolder_contents)
            self.assertIn(ov_bin_file_name, subfolder_contents)

        compile_only_pipeline = OVDiffusionPipeline.from_pretrained(tmpdirname, compile_only=True)
        self.assertIsInstance(compile_only_pipeline, OVFluxPipeline)
        # In compile-only mode every sub-model must already be a CompiledModel.
        self.assertIsInstance(compile_only_pipeline.transformer.model, ov.runtime.CompiledModel)
        self.assertIsInstance(compile_only_pipeline.text_encoder.model, ov.runtime.CompiledModel)
        self.assertIsInstance(compile_only_pipeline.text_encoder_2.model, ov.runtime.CompiledModel)
        self.assertIsInstance(compile_only_pipeline.vae_encoder.model, ov.runtime.CompiledModel)
        self.assertIsInstance(compile_only_pipeline.vae_decoder.model, ov.runtime.CompiledModel)

        np.random.seed(0)
        torch.manual_seed(0)
        outputs = compile_only_pipeline(**inputs).images
        np.testing.assert_allclose(pipeline_outputs, outputs, atol=1e-4, rtol=1e-4)
        del compile_only_pipeline

    np.random.seed(0)
    torch.manual_seed(0)
    outputs = pipeline(**inputs).images
    np.testing.assert_allclose(pipeline_outputs, outputs, atol=1e-4, rtol=1e-4)
    del pipeline
    gc.collect()
464+
394465
@pytest.mark.run_slow
395466
@slow
396467
def test_load_model_from_hub_private_with_token(self):

0 commit comments

Comments
 (0)