Skip to content

Commit 636a613

Browse files
committed
reformat
1 parent 2dc4087 commit 636a613

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

optimum/intel/openvino/modeling_diffusion.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,13 @@ def _from_pretrained(
287287
# load the UNet model uncompressed to apply hybrid quantization further
288288
unet = cls.load_model(unet_path)
289289
# Apply weights compression to other `components` without dataset
290-
q_config_params = quantization_config.__dict__
291-
wc_params = {param: value for param, value in q_config_params.items() if param != "dataset"}
292-
wc_quantization_config = OVWeightQuantizationConfig.from_dict(wc_params)
290+
weight_quantization_params = {
291+
param: value for param, value in quantization_config.__dict__.items() if param != "dataset"
292+
}
293+
weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_params)
293294
else:
294-
wc_quantization_config = quantization_config
295-
unet = cls.load_model(unet_path, wc_quantization_config)
295+
weight_quantization_config = quantization_config
296+
unet = cls.load_model(unet_path, weight_quantization_config)
296297

297298
components = {
298299
"vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
@@ -302,7 +303,7 @@ def _from_pretrained(
302303
}
303304

304305
for key, value in components.items():
305-
components[key] = cls.load_model(value, wc_quantization_config) if value.is_file() else None
306+
components[key] = cls.load_model(value, weight_quantization_config) if value.is_file() else None
306307

307308
if model_save_dir is None:
308309
model_save_dir = new_model_save_dir
@@ -323,7 +324,7 @@ def _from_pretrained(
323324

324325
from .quantization import _hybrid_quantization
325326

326-
unet = _hybrid_quantization(sd_model.unet.model, wc_quantization_config, dataset=unet_inputs)
327+
unet = _hybrid_quantization(sd_model.unet.model, weight_quantization_config, dataset=unet_inputs)
327328

328329
return cls(
329330
unet=unet,

0 commit comments

Comments
 (0)