@@ -292,33 +292,27 @@ def _from_pretrained(
                else:
                    kwargs[name] = load_method(new_model_save_dir)

-        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
-
        unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
-        if quantization_config is not None and quantization_config.dataset is not None:
-            # load the UNet model uncompressed to apply hybrid quantization further
-            unet = cls.load_model(unet_path)
-            # Apply weights compression to other `components` without dataset
-            quantization_config_without_dataset = deepcopy(quantization_config)
-            quantization_config_without_dataset.dataset = None
-        else:
-            quantization_config_without_dataset = quantization_config
-            unet = cls.load_model(unet_path, quantization_config_without_dataset)
-
        components = {
            "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
            "vae_decoder": new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
            "text_encoder": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
            "text_encoder_2": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name,
        }

-        for key, value in components.items():
-            components[key] = cls.load_model(value, quantization_config_without_dataset) if value.is_file() else None
-
        if model_save_dir is None:
            model_save_dir = new_model_save_dir

-        if quantization_config is not None and quantization_config.dataset is not None:
+        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
+        if quantization_config is None or quantization_config.dataset is None:
+            unet = cls.load_model(unet_path, quantization_config)
+            for key, value in components.items():
+                components[key] = cls.load_model(value, quantization_config) if value.is_file() else None
+        else:
+            # Load uncompressed models to apply hybrid quantization further
+            unet = cls.load_model(unet_path)
+            for key, value in components.items():
+                components[key] = cls.load_model(value) if value.is_file() else None
            sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)

            supported_pipelines = (
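As a side note on the control flow introduced above, here is a minimal standalone sketch (not part of the diff) of how the dataset attribute of an OVWeightQuantizationConfig selects between the two branches; the bit width and dataset name are illustrative values only.

from optimum.intel import OVWeightQuantizationConfig

# No dataset: the weight-only branch passes the config straight to load_model(),
# so every submodel is compressed at load time.
weight_only_config = OVWeightQuantizationConfig(bits=8)
assert weight_only_config.dataset is None

# With a dataset: the else-branch loads submodels uncompressed so that
# hybrid quantization can be applied to the assembled pipeline afterwards.
hybrid_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions")
assert hybrid_config.dataset is not None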
@@ -331,10 +325,10 @@ def _from_pretrained(

            from optimum.intel import OVQuantizer

+            hybrid_quantization_config = deepcopy(quantization_config)
+            hybrid_quantization_config.quant_method = OVQuantizationMethod.HYBRID
            quantizer = OVQuantizer(sd_model)
-            quantization_config_copy = deepcopy(quantization_config)
-            quantization_config_copy.quant_method = OVQuantizationMethod.HYBRID
-            quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
+            quantizer.quantize(ov_config=OVConfig(quantization_config=hybrid_quantization_config))

            return sd_model
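For completeness, a hedged end-to-end sketch of the hybrid path that the renamed hybrid_quantization_config feeds into, using the public OVStableDiffusionPipeline entry point; the model ID, dataset name, and output directory are placeholders.

from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

# A quantization_config with a dataset routes loading through the else-branch
# above: submodels are loaded uncompressed and OVQuantizer then applies
# OVQuantizationMethod.HYBRID to the assembled pipeline.
pipeline = OVStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions"),
)
pipeline.save_pretrained("stable-diffusion-v1-5-hybrid-int8")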
@@ -347,6 +341,7 @@ def _from_pretrained(
            **kwargs,
        )

+
    @classmethod
    def _from_transformers(
        cls,