Skip to content

Commit 726191f

Browse files
authored
update setting activation scale for diffusers (#1110)
* update setting activation scale for diffusers * fix style * apply comments
1 parent 878b474 commit 726191f

File tree

4 files changed

+104
-72
lines changed

4 files changed

+104
-72
lines changed

optimum/exporters/openvino/convert.py

+29-41
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
_torch_version,
4747
_transformers_version,
4848
compare_versions,
49-
is_diffusers_version,
5049
is_openvino_tokenizers_version,
5150
is_openvino_version,
5251
is_tokenizers_version,
@@ -104,10 +103,10 @@ def _set_runtime_options(
104103
):
105104
for model_name in models_and_export_configs.keys():
106105
_, sub_export_config = models_and_export_configs[model_name]
107-
sub_export_config.runtime_options = {}
106+
if not hasattr(sub_export_config, "runtime_options"):
107+
sub_export_config.runtime_options = {}
108108
if (
109-
"diffusers" in library_name
110-
or "text-generation" in task
109+
"text-generation" in task
111110
or ("image-text-to-text" in task and model_name == "language_model")
112111
or getattr(sub_export_config, "stateful", False)
113112
):
@@ -1014,45 +1013,29 @@ def _get_submodels_and_export_configs(
10141013
def get_diffusion_models_for_export_ext(
10151014
pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
10161015
):
1017-
if is_diffusers_version(">=", "0.29.0"):
1018-
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
1019-
1020-
sd3_pipes = [StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline]
1021-
if is_diffusers_version(">=", "0.30.0"):
1022-
from diffusers import StableDiffusion3InpaintPipeline
1023-
1024-
sd3_pipes.append(StableDiffusion3InpaintPipeline)
1025-
1026-
is_sd3 = isinstance(pipeline, tuple(sd3_pipes))
1027-
else:
1028-
is_sd3 = False
1029-
1030-
if is_diffusers_version(">=", "0.30.0"):
1031-
from diffusers import FluxPipeline
1032-
1033-
flux_pipes = [FluxPipeline]
1034-
1035-
if is_diffusers_version(">=", "0.31.0"):
1036-
from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline
1037-
1038-
flux_pipes.extend([FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline])
1039-
1040-
if is_diffusers_version(">=", "0.32.0"):
1041-
from diffusers import FluxFillPipeline
1042-
1043-
flux_pipes.append(FluxFillPipeline)
1044-
1045-
is_flux = isinstance(pipeline, tuple(flux_pipes))
1046-
else:
1047-
is_flux = False
1048-
1049-
if not is_sd3 and not is_flux:
1050-
return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
1051-
if is_sd3:
1016+
is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
1017+
is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
1018+
is_flux = pipeline.__class__.__name__.startswith("Flux")
1019+
is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
1020+
is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")
1021+
1022+
if is_sd or is_sdxl or is_lcm:
1023+
models_for_export = get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
1024+
if is_sdxl and pipeline.vae.config.force_upcast:
1025+
models_for_export["vae_encoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}
1026+
models_for_export["vae_decoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}
1027+
1028+
# only SD 2.1 has overflow issue, it uses different prediction_type than other models
1029+
if is_sd and pipeline.scheduler.config.prediction_type == "v_prediction":
1030+
models_for_export["vae_encoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
1031+
models_for_export["vae_decoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
1032+
1033+
elif is_sd3:
10521034
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
1053-
else:
1035+
elif is_flux:
10541036
models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
1055-
1037+
else:
1038+
raise ValueError(f"Unsupported pipeline type `{pipeline.__class__.__name__}` provided")
10561039
return None, models_for_export
10571040

10581041

@@ -1150,6 +1133,7 @@ def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
11501133
int_dtype=int_dtype,
11511134
float_dtype=float_dtype,
11521135
)
1136+
export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
11531137
models_for_export["text_encoder_3"] = (text_encoder_3, export_config)
11541138

11551139
return models_for_export
@@ -1187,6 +1171,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
11871171
transformer_export_config = export_config_constructor(
11881172
pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
11891173
)
1174+
transformer_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
11901175
models_for_export["transformer"] = (transformer, transformer_export_config)
11911176

11921177
# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
@@ -1202,6 +1187,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
12021187
vae_encoder_export_config = vae_config_constructor(
12031188
vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
12041189
)
1190+
vae_encoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
12051191
models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
12061192

12071193
# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
@@ -1217,6 +1203,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
12171203
vae_decoder_export_config = vae_config_constructor(
12181204
vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
12191205
)
1206+
vae_decoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
12201207
models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
12211208

12221209
text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
@@ -1233,6 +1220,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
12331220
int_dtype=int_dtype,
12341221
float_dtype=float_dtype,
12351222
)
1223+
export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
12361224
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
12371225

12381226
return models_for_export

optimum/intel/openvino/modeling_diffusion.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
)
6464

6565
from ...exporters.openvino import main_export
66-
from ..utils.import_utils import is_diffusers_version
66+
from ..utils.import_utils import is_diffusers_version, is_openvino_version
6767
from .configuration import OVConfig, OVQuantizationMethod, OVWeightQuantizationConfig
6868
from .loaders import OVTextualInversionLoaderMixin
6969
from .modeling_base import OVBaseModel
@@ -73,6 +73,7 @@
7373
OV_XML_FILE_NAME,
7474
TemporaryDirectory,
7575
_print_compiled_model_properties,
76+
check_scale_available,
7677
model_has_dynamic_inputs,
7778
np_to_pt_generators,
7879
)
@@ -484,8 +485,15 @@ def _from_pretrained(
484485
ov_config = kwargs.get("ov_config", {})
485486
device = kwargs.get("device", "CPU")
486487
vae_ov_conifg = {**ov_config}
487-
if "GPU" in device.upper() and "INFERENCE_PRECISION_HINT" not in vae_ov_conifg:
488-
vae_ov_conifg["INFERENCE_PRECISION_HINT"] = "f32"
488+
if (
489+
"GPU" in device.upper()
490+
and "INFERENCE_PRECISION_HINT" not in vae_ov_conifg
491+
and is_openvino_version("<=", "2025.0")
492+
):
493+
vae_model_path = models["vae_decoder"]
494+
required_upcast = check_scale_available(vae_model_path)
495+
if required_upcast:
496+
vae_ov_conifg["INFERENCE_PRECISION_HINT"] = "f32"
489497
for name, path in models.items():
490498
if name in kwargs:
491499
models[name] = kwargs.pop(name)
@@ -1202,7 +1210,12 @@ def forward(
12021210
return ModelOutput(**model_outputs)
12031211

12041212
def _compile(self):
1205-
if "GPU" in self._device and "INFERENCE_PRECISION_HINT" not in self.ov_config:
1213+
if (
1214+
"GPU" in self._device
1215+
and "INFERENCE_PRECISION_HINT" not in self.ov_config
1216+
and is_openvino_version("<", "2025.0")
1217+
and check_scale_available(self.model)
1218+
):
12061219
self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"})
12071220
super()._compile()
12081221

@@ -1241,7 +1254,12 @@ def forward(
12411254
return ModelOutput(**model_outputs)
12421255

12431256
def _compile(self):
1244-
if "GPU" in self._device and "INFERENCE_PRECISION_HINT" not in self.ov_config:
1257+
if (
1258+
"GPU" in self._device
1259+
and "INFERENCE_PRECISION_HINT" not in self.ov_config
1260+
and is_openvino_version("<", "2025.0")
1261+
and check_scale_available(self.model)
1262+
):
12451263
self.ov_config.update({"INFERENCE_PRECISION_HINT": "f32"})
12461264
super()._compile()
12471265

optimum/intel/openvino/utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -565,3 +565,21 @@ def onexc(func, path, exc):
565565
def cleanup(self):
566566
if self._finalizer.detach() or os.path.exists(self.name):
567567
self._rmtree(self.name, ignore_errors=self._ignore_cleanup_errors)
568+
569+
570+
def check_scale_available(model: Union[Model, str, Path]):
571+
if isinstance(model, Model):
572+
return model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
573+
if not Path(model).exists():
574+
return False
575+
import xml.etree.ElementTree as ET
576+
577+
tree = ET.parse(model)
578+
root = tree.getroot()
579+
rt_info = root.find("rt_info")
580+
if rt_info is None:
581+
return False
582+
runtime_options = rt_info.find("runtime_options")
583+
if runtime_options is None:
584+
return False
585+
return runtime_options.find("ACTIVATIONS_SCALE_FACTOR") is not None

tests/openvino/test_export.py

+34-26
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ class ExportModelTest(unittest.TestCase):
7575
"llava": OVModelForVisualCausalLM,
7676
}
7777

78+
EXPECTED_DIFFUSERS_SCALE_FACTORS = {
79+
"stable-diffusion-xl": {"vae_encoder": "128.0", "vae_decoder": "128.0"},
80+
"stable-diffusion-3": {"text_encoder_3": "8.0"},
81+
"flux": {"text_encoder_2": "8.0", "transformer": "8.0", "vae_encoder": "8.0", "vae_decoder": "8.0"},
82+
"stable-diffusion-xl-refiner": {"vae_encoder": "128.0", "vae_decoder": "128.0"},
83+
}
84+
7885
if is_transformers_version(">=", "4.45"):
7986
SUPPORTED_ARCHITECTURES.update({"stable-diffusion-3": OVStableDiffusion3Pipeline, "flux": OVFluxPipeline})
8087

@@ -143,32 +150,33 @@ def _openvino_export(
143150
)
144151

145152
if library_name == "diffusers":
146-
self.assertTrue(
147-
ov_model.vae_encoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
148-
)
149-
self.assertTrue(
150-
ov_model.vae_decoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
151-
)
152-
if hasattr(ov_model, "text_encoder") and ov_model.text_encoder:
153-
self.assertTrue(
154-
ov_model.text_encoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
155-
)
156-
if hasattr(ov_model, "text_encoder_2") and ov_model.text_encoder_2:
157-
self.assertTrue(
158-
ov_model.text_encoder_2.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
159-
)
160-
if hasattr(ov_model, "text_encoder_3") and ov_model.text_encoder_3:
161-
self.assertTrue(
162-
ov_model.text_encoder_3.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
163-
)
164-
if hasattr(ov_model, "unet") and ov_model.unet:
165-
self.assertTrue(
166-
ov_model.unet.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
167-
)
168-
if hasattr(ov_model, "transformer") and ov_model.transformer:
169-
self.assertTrue(
170-
ov_model.transformer.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
171-
)
153+
expected_scale_factors = self.EXPECTED_DIFFUSERS_SCALE_FACTORS.get(model_type, {})
154+
components = [
155+
"unet",
156+
"transformer",
157+
"text_encoder",
158+
"text_encoder_2",
159+
"text_encoder_3",
160+
"vae_encoder",
161+
"vae_decoder",
162+
]
163+
for component in components:
164+
component_model = getattr(ov_model, component, None)
165+
if component_model is None:
166+
continue
167+
component_scale = expected_scale_factors.get(component)
168+
if component_scale is not None:
169+
self.assertTrue(
170+
component_model.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
171+
)
172+
self.assertEqual(
173+
component_model.model.get_rt_info()["runtime_options"]["ACTIVATIONS_SCALE_FACTOR"],
174+
component_scale,
175+
)
176+
else:
177+
self.assertFalse(
178+
component_model.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
179+
)
172180

173181
@parameterized.expand(SUPPORTED_ARCHITECTURES)
174182
def test_export(self, model_type: str):

0 commit comments

Comments
 (0)