diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index 8cfc9529a2..d1204ca78d 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -162,10 +162,11 @@ def __init__(
                     "Please provide `compile=True` if you want to use `compile_only=True` or set `compile_only=False`"
                 )
 
-            if not isinstance(unet, openvino.runtime.CompiledModel):
+            main_model = unet if unet is not None else transformer
+            if not isinstance(main_model, openvino.runtime.CompiledModel):
                 raise ValueError("`compile_only` expect that already compiled model will be provided")
 
-            model_is_dynamic = model_has_dynamic_inputs(unet)
+            model_is_dynamic = model_has_dynamic_inputs(main_model)
             if dynamic_shapes ^ model_is_dynamic:
                 requested_shapes = "dynamic" if dynamic_shapes else "static"
                 compiled_shapes = "dynamic" if model_is_dynamic else "static"
@@ -291,6 +292,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
                 if config_path.is_file():
                     config_save_path = save_path / CONFIG_NAME
                     shutil.copyfile(config_path, config_save_path)
+                else:
+                    if hasattr(model, "save_config"):
+                        model.save_config(save_path)
+                    elif hasattr(model, "config") and hasattr(model.config, "save_pretrained"):
+                        model.config.save_pretrained(save_path)
 
         self.scheduler.save_pretrained(save_directory / "scheduler")
 
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 2a597276f1..0fe43d71e0 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -67,6 +67,8 @@
 
 from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
 from optimum.intel import (
+    OVDiffusionPipeline,
+    OVFluxPipeline,
     OVModelForAudioClassification,
     OVModelForAudioFrameClassification,
     OVModelForAudioXVector,
@@ -107,7 +109,9 @@
 from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
 from optimum.intel.utils.modeling_utils import _find_files_matching_pattern
 from optimum.utils import (
+    DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
     DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
+    DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER,
     DIFFUSION_MODEL_UNET_SUBFOLDER,
     DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
     DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
@@ -140,7 +144,8 @@ def __init__(self, *args, **kwargs):
         self.OV_MODEL_ID = "echarlaix/distilbert-base-uncased-finetuned-sst-2-english-openvino"
         self.OV_DECODER_MODEL_ID = "helenai/gpt2-ov"
         self.OV_SEQ2SEQ_MODEL_ID = "echarlaix/t5-small-openvino"
-        self.OV_DIFFUSION_MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-openvino"
+        self.OV_SD_DIFFUSION_MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-openvino"
+        self.OV_FLUX_DIFFUSION_MODEL_ID = "katuni4ka/tiny-random-flux-ov"
         self.OV_VLM_MODEL_ID = "katuni4ka/tiny-random-llava-ov"
 
     def test_load_from_hub_and_save_model(self):
@@ -337,7 +342,7 @@ def test_load_from_hub_and_save_seq2seq_model(self):
 
     @require_diffusers
     def test_load_from_hub_and_save_stable_diffusion_model(self):
-        loaded_pipeline = OVStableDiffusionPipeline.from_pretrained(self.OV_DIFFUSION_MODEL_ID, compile=False)
+        loaded_pipeline = OVStableDiffusionPipeline.from_pretrained(self.OV_SD_DIFFUSION_MODEL_ID, compile=False)
         self.assertIsInstance(loaded_pipeline.config, Dict)
         # Test that PERFORMANCE_HINT is set to LATENCY by default
         self.assertEqual(loaded_pipeline.ov_config.get("PERFORMANCE_HINT"), "LATENCY")
@@ -391,6 +396,72 @@ def test_load_from_hub_and_save_stable_diffusion_model(self):
         del pipeline
         gc.collect()
 
+    @require_diffusers
+    @unittest.skipIf(
+        is_transformers_version("<", "4.45"),
+        "model tokenizer exported with tokenizers 0.20 is not compatible with old transformers",
+    )
+    def test_load_from_hub_and_save_flux_model(self):
+        loaded_pipeline = OVDiffusionPipeline.from_pretrained(self.OV_FLUX_DIFFUSION_MODEL_ID, compile=False)
+        self.assertIsInstance(loaded_pipeline, OVFluxPipeline)
+        self.assertIsInstance(loaded_pipeline.config, Dict)
+        # Test that PERFORMANCE_HINT is set to LATENCY by default
+        self.assertEqual(loaded_pipeline.ov_config.get("PERFORMANCE_HINT"), "LATENCY")
+        loaded_pipeline.compile()
+        self.assertIsNone(loaded_pipeline.unet)
+        self.assertEqual(loaded_pipeline.transformer.request.get_property("PERFORMANCE_HINT"), "LATENCY")
+        batch_size, height, width = 2, 16, 16
+        inputs = {
+            "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size,
+            "height": height,
+            "width": width,
+            "num_inference_steps": 2,
+            "output_type": "np",
+        }
+
+        np.random.seed(0)
+        torch.manual_seed(0)
+        pipeline_outputs = loaded_pipeline(**inputs).images
+        self.assertEqual(pipeline_outputs.shape, (batch_size, height, width, 3))
+
+        with TemporaryDirectory() as tmpdirname:
+            loaded_pipeline.save_pretrained(tmpdirname)
+            pipeline = OVDiffusionPipeline.from_pretrained(tmpdirname)
+            self.assertIsInstance(loaded_pipeline, OVFluxPipeline)
+            folder_contents = os.listdir(tmpdirname)
+            self.assertIn(loaded_pipeline.config_name, folder_contents)
+            for subfoler in {
+                DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER,
+                DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
+                DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
+                DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
+                DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
+            }:
+                folder_contents = os.listdir(os.path.join(tmpdirname, subfoler))
+                self.assertIn(OV_XML_FILE_NAME, folder_contents)
+                self.assertIn(OV_XML_FILE_NAME.replace(".xml", ".bin"), folder_contents)
+
+            compile_only_pipeline = OVDiffusionPipeline.from_pretrained(tmpdirname, compile_only=True)
+            self.assertIsInstance(compile_only_pipeline, OVFluxPipeline)
+            self.assertIsInstance(compile_only_pipeline.transformer.model, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_pipeline.text_encoder.model, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_pipeline.text_encoder_2.model, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_pipeline.vae_encoder.model, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_pipeline.vae_decoder.model, ov.runtime.CompiledModel)
+
+            np.random.seed(0)
+            torch.manual_seed(0)
+            outputs = compile_only_pipeline(**inputs).images
+            np.testing.assert_allclose(pipeline_outputs, outputs, atol=1e-4, rtol=1e-4)
+            del compile_only_pipeline
+
+        np.random.seed(0)
+        torch.manual_seed(0)
+        outputs = pipeline(**inputs).images
+        np.testing.assert_allclose(pipeline_outputs, outputs, atol=1e-4, rtol=1e-4)
+        del pipeline
+        gc.collect()
+
     @pytest.mark.run_slow
     @slow
     def test_load_model_from_hub_private_with_token(self):