@@ -282,16 +282,17 @@ def _from_pretrained(

         quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

-        dataset = None
         unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
         if quantization_config is not None and quantization_config.dataset is not None:
-            dataset = quantization_config.dataset
             # load the UNet model uncompressed to apply hybrid quantization further
             unet = cls.load_model(unet_path)
             # Apply weights compression to other `components` without dataset
-            quantization_config.dataset = None
+            q_config_params = quantization_config.__dict__
+            wc_params = {param: value for param, value in q_config_params.items() if param != "dataset"}
+            wc_quantization_config = OVWeightQuantizationConfig.from_dict(wc_params)
         else:
-            unet = cls.load_model(unet_path, quantization_config)
+            wc_quantization_config = quantization_config
+            unet = cls.load_model(unet_path, wc_quantization_config)

         components = {
             "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
@@ -301,12 +302,12 @@ def _from_pretrained(
         }

         for key, value in components.items():
-            components[key] = cls.load_model(value, quantization_config) if value.is_file() else None
+            components[key] = cls.load_model(value, wc_quantization_config) if value.is_file() else None

         if model_save_dir is None:
             model_save_dir = new_model_save_dir

-        if dataset is not None:
+        if quantization_config is not None and quantization_config.dataset is not None:
             sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)

             supported_pipelines = (
@@ -318,12 +319,11 @@ def _from_pretrained(
                 raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}")

             nsamples = quantization_config.num_samples if quantization_config.num_samples else 200
-            unet_inputs = sd_model._prepare_unet_inputs(dataset, nsamples)
+            unet_inputs = sd_model._prepare_unet_inputs(quantization_config.dataset, nsamples)

             from .quantization import _hybrid_quantization

-            unet = _hybrid_quantization(sd_model.unet.model, quantization_config, dataset=unet_inputs)
-            quantization_config.dataset = dataset
+            unet = _hybrid_quantization(sd_model.unet.model, wc_quantization_config, dataset=unet_inputs)

         return cls(
             unet=unet,
@@ -338,13 +338,17 @@ def _prepare_unet_inputs(
         self,
         dataset: Union[str, List[Any]],
         num_samples: int,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         seed: Optional[int] = 42,
         **kwargs,
     ) -> Dict[str, Any]:
         self.compile()

+        size = self.unet.config.get("sample_size", 64) * self.vae_scale_factor
+        height = height or min(size, 512)
+        width = width or min(size, 512)
+
         if isinstance(dataset, str):
            dataset = deepcopy(dataset)
            available_datasets = PREDEFINED_SD_DATASETS.keys()
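With the new defaults, the calibration resolution follows the pipeline's native resolution (UNet `sample_size` times the VAE scale factor), capped at 512, rather than a fixed 512x512. A small standalone illustration of that rule (the helper name is hypothetical, not part of this change):

```python
def default_calibration_resolution(sample_size: int, vae_scale_factor: int) -> int:
    # Hypothetical helper mirroring the new default: the UNet's native resolution, capped at 512.
    return min(sample_size * vae_scale_factor, 512)

print(default_calibration_resolution(64, 8))  # 512 -> standard SD pipelines still calibrate at 512x512
print(default_calibration_resolution(32, 8))  # 256 -> smaller models now calibrate at their native 256x256
```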