Skip to content

Commit b78881a

Browse files
committed
add test
1 parent 95e3c82 commit b78881a

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

optimum/intel/openvino/modeling_diffusion.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
294294
dst_path = save_path / OV_XML_FILE_NAME
295295
dst_path.parent.mkdir(parents=True, exist_ok=True)
296296
openvino.save_model(model.model, dst_path, compress_to_fp16=False)
297-
model_dir = self.config.get("_name_or_path", None) or model.model_save_dir
297+
model_dir = (
298+
self.model_save_dir
299+
if not isinstance(self.model_save_dir, TemporaryDirectory)
300+
else self.model_save_dir.name
301+
)
298302
config_path = Path(model_dir) / save_path.name / CONFIG_NAME
299303
if config_path.is_file():
300304
config_save_path = save_path / CONFIG_NAME

tests/openvino/test_diffusion.py

+7
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,13 @@ def test_load_and_save_pipeline_with_safety_checker(self):
359359
self.assertTrue(model_lib in ["diffusers", "transformers"])
360360
self.assertFalse(model_class.startswith("OV"))
361361
loaded_pipeline = self.OVMODEL_CLASS.from_pretrained(tmpdirname)
362+
for component in ["text_encoder", "unet", "vae_encoder", "vae_decoder"]:
363+
config = getattr(getattr(ov_pipeline, component), "config", None)
364+
if config is not None:
365+
loaded_config = getattr(getattr(loaded_pipeline, component), "config")
366+
self.assertDictEqual(
367+
config, loaded_config, f"Expected config:\n{config}\nLoaded config:|n{loaded_config}"
368+
)
362369
self.assertTrue(loaded_pipeline.safety_checker is not None)
363370
self.assertIsInstance(loaded_pipeline.safety_checker, StableDiffusionSafetyChecker)
364371
del loaded_pipeline

0 commit comments

Comments
 (0)