57
57
)
58
58
59
59
from ...exporters .openvino import main_export
60
- from .configuration import OVConfig , OVWeightQuantizationConfig
60
+ from .configuration import OVConfig , OVQuantizationMethod , OVWeightQuantizationConfig
61
61
from .loaders import OVTextualInversionLoaderMixin
62
62
from .modeling_base import OVBaseModel
63
63
from .utils import (
64
64
ONNX_WEIGHTS_NAME ,
65
65
OV_TO_NP_TYPE ,
66
66
OV_XML_FILE_NAME ,
67
- PREDEFINED_SD_DATASETS ,
68
67
_print_compiled_model_properties ,
69
68
)
70
69
@@ -293,35 +292,27 @@ def _from_pretrained(
293
292
else :
294
293
kwargs [name ] = load_method (new_model_save_dir )
295
294
296
- quantization_config = cls ._prepare_weight_quantization_config (quantization_config , load_in_8bit )
297
-
298
295
unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
299
- if quantization_config is not None and quantization_config .dataset is not None :
300
- # load the UNet model uncompressed to apply hybrid quantization further
301
- unet = cls .load_model (unet_path )
302
- # Apply weights compression to other `components` without dataset
303
- weight_quantization_params = {
304
- param : value for param , value in quantization_config .__dict__ .items () if param != "dataset"
305
- }
306
- weight_quantization_config = OVWeightQuantizationConfig .from_dict (weight_quantization_params )
307
- else :
308
- weight_quantization_config = quantization_config
309
- unet = cls .load_model (unet_path , weight_quantization_config )
310
-
311
296
components = {
312
297
"vae_encoder" : new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name ,
313
298
"vae_decoder" : new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name ,
314
299
"text_encoder" : new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name ,
315
300
"text_encoder_2" : new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name ,
316
301
}
317
302
318
- for key , value in components .items ():
319
- components [key ] = cls .load_model (value , weight_quantization_config ) if value .is_file () else None
320
-
321
303
if model_save_dir is None :
322
304
model_save_dir = new_model_save_dir
323
305
324
- if quantization_config is not None and quantization_config .dataset is not None :
306
+ quantization_config = cls ._prepare_weight_quantization_config (quantization_config , load_in_8bit )
307
+ if quantization_config is None or quantization_config .dataset is None :
308
+ unet = cls .load_model (unet_path , quantization_config )
309
+ for key , value in components .items ():
310
+ components [key ] = cls .load_model (value , quantization_config ) if value .is_file () else None
311
+ else :
312
+ # Load uncompressed models to apply hybrid quantization further
313
+ unet = cls .load_model (unet_path )
314
+ for key , value in components .items ():
315
+ components [key ] = cls .load_model (value ) if value .is_file () else None
325
316
sd_model = cls (unet = unet , config = config , model_save_dir = model_save_dir , ** components , ** kwargs )
326
317
327
318
supported_pipelines = (
@@ -332,12 +323,14 @@ def _from_pretrained(
332
323
if not isinstance (sd_model , supported_pipelines ):
333
324
raise NotImplementedError (f"Quantization in hybrid mode is not supported for { cls .__name__ } " )
334
325
335
- nsamples = quantization_config .num_samples if quantization_config .num_samples else 200
336
- unet_inputs = sd_model ._prepare_unet_inputs (quantization_config .dataset , nsamples )
326
+ from optimum .intel import OVQuantizer
337
327
338
- from .quantization import _hybrid_quantization
328
+ hybrid_quantization_config = deepcopy (quantization_config )
329
+ hybrid_quantization_config .quant_method = OVQuantizationMethod .HYBRID
330
+ quantizer = OVQuantizer (sd_model )
331
+ quantizer .quantize (ov_config = OVConfig (quantization_config = hybrid_quantization_config ))
339
332
340
- unet = _hybrid_quantization ( sd_model . unet . model , weight_quantization_config , dataset = unet_inputs )
333
+ return sd_model
341
334
342
335
return cls (
343
336
unet = unet ,
@@ -348,62 +341,6 @@ def _from_pretrained(
348
341
** kwargs ,
349
342
)
350
343
351
- def _prepare_unet_inputs (
352
- self ,
353
- dataset : Union [str , List [Any ]],
354
- num_samples : int ,
355
- height : Optional [int ] = None ,
356
- width : Optional [int ] = None ,
357
- seed : Optional [int ] = 42 ,
358
- ** kwargs ,
359
- ) -> Dict [str , Any ]:
360
- self .compile ()
361
-
362
- size = self .unet .config .get ("sample_size" , 64 ) * self .vae_scale_factor
363
- height = height or min (size , 512 )
364
- width = width or min (size , 512 )
365
-
366
- if isinstance (dataset , str ):
367
- dataset = deepcopy (dataset )
368
- available_datasets = PREDEFINED_SD_DATASETS .keys ()
369
- if dataset not in available_datasets :
370
- raise ValueError (
371
- f"""You have entered a string value for dataset. You can only choose between
372
- { list (available_datasets )} , but the { dataset } was found"""
373
- )
374
-
375
- from datasets import load_dataset
376
-
377
- dataset_metadata = PREDEFINED_SD_DATASETS [dataset ]
378
- dataset = load_dataset (dataset , split = dataset_metadata ["split" ], streaming = True ).shuffle (seed = seed )
379
- input_names = dataset_metadata ["inputs" ]
380
- dataset = dataset .select_columns (list (input_names .values ()))
381
-
382
- def transform_fn (data_item ):
383
- return {inp_name : data_item [column ] for inp_name , column in input_names .items ()}
384
-
385
- else :
386
-
387
- def transform_fn (data_item ):
388
- return data_item if isinstance (data_item , (list , dict )) else [data_item ]
389
-
390
- from .quantization import InferRequestWrapper
391
-
392
- calibration_data = []
393
- self .unet .request = InferRequestWrapper (self .unet .request , calibration_data )
394
-
395
- for inputs in dataset :
396
- inputs = transform_fn (inputs )
397
- if isinstance (inputs , dict ):
398
- self .__call__ (** inputs , height = height , width = width )
399
- else :
400
- self .__call__ (* inputs , height = height , width = width )
401
- if len (calibration_data ) >= num_samples :
402
- break
403
-
404
- self .unet .request = self .unet .request .request
405
- return calibration_data [:num_samples ]
406
-
407
344
@classmethod
408
345
def _from_transformers (
409
346
cls ,
0 commit comments