import importlib
import logging
-import math
import os
import shutil
from copy import deepcopy

from .configuration import OVConfig, OVWeightQuantizationConfig
from .loaders import OVTextualInversionLoaderMixin
from .modeling_base import OVBaseModel
-from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, _print_compiled_model_properties
+from .utils import (
+    ONNX_WEIGHTS_NAME,
+    OV_TO_NP_TYPE,
+    OV_XML_FILE_NAME,
+    PREDEFINED_SD_DATASETS,
+    _print_compiled_model_properties,
+)

core = Core()
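Note on the new PREDEFINED_SD_DATASETS import: the code below reads dataset_metadata["split"] and dataset_metadata["inputs"] from it, so the registry appears to map a dataset id to the split to stream and a mapping from pipeline argument names to dataset columns. A minimal sketch of that assumed shape; the entries are illustrative, not the actual contents of .utils:

```python
# Illustrative sketch only: assumed shape of the PREDEFINED_SD_DATASETS registry,
# inferred from how dataset_metadata["split"] and dataset_metadata["inputs"] are used below.
PREDEFINED_SD_DATASETS = {
    # key: dataset id passed to datasets.load_dataset();
    # "inputs" maps a pipeline argument ("prompt") to the dataset column holding it.
    "conceptual_captions": {"split": "train", "inputs": {"prompt": "caption"}},
    "laion/filtered-wit": {"split": "train", "inputs": {"prompt": "caption"}},
}
```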
@@ -276,13 +281,15 @@ def _from_pretrained(
            kwargs[name] = load_method(new_model_save_dir)

        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
-        weight_quantization_config = deepcopy(quantization_config)
+
+        dataset = None
        unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
-        if weight_quantization_config is not None and weight_quantization_config.dataset is not None:
+        if quantization_config is not None and quantization_config.dataset is not None:
+            dataset = quantization_config.dataset
            # load the UNet model uncompressed to apply hybrid quantization further
            unet = cls.load_model(unet_path)
            # Apply weights compression to other `components` without dataset
-            weight_quantization_config.dataset = None
+            quantization_config.dataset = None
        else:
            unet = cls.load_model(unet_path, quantization_config)
@@ -294,12 +301,12 @@ def _from_pretrained(
        }

        for key, value in components.items():
-            components[key] = cls.load_model(value, weight_quantization_config) if value.is_file() else None
+            components[key] = cls.load_model(value, quantization_config) if value.is_file() else None

        if model_save_dir is None:
            model_save_dir = new_model_save_dir

-        if quantization_config and quantization_config.dataset is not None:
+        if dataset is not None:
            sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)

            supported_pipelines = (
@@ -310,24 +317,13 @@ def _from_pretrained(
            if not isinstance(sd_model, supported_pipelines):
                raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}")

-            num_inference_steps = 4 if isinstance(sd_model, OVLatentConsistencyModelPipeline) else 50
            nsamples = quantization_config.num_samples if quantization_config.num_samples else 200
-            dataset = deepcopy(quantization_config.dataset)
-
-            if isinstance(dataset, str):
-                from .quantization import get_stable_diffusion_dataset
-
-                num_unet_runs = math.ceil(nsamples / num_inference_steps)
-                dataset = get_stable_diffusion_dataset(dataset, num_unet_runs)
-
-            unet_inputs = sd_model._prepare_unet_inputs(dataset, nsamples, num_inference_steps)
+            unet_inputs = sd_model._prepare_unet_inputs(dataset, nsamples)

            from .quantization import _hybrid_quantization

-            hybrid_quantization_config = deepcopy(quantization_config)
-            hybrid_quantization_config.dataset = unet_inputs
-            hybrid_quantization_config.num_samples = nsamples
-            unet = _hybrid_quantization(sd_model.unet.model, hybrid_quantization_config)
+            unet = _hybrid_quantization(sd_model.unet.model, quantization_config, dataset=unet_inputs)
+            quantization_config.dataset = dataset

        return cls(
            unet=unet,
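For reference, the hybrid path above is selected purely by whether the OVWeightQuantizationConfig passed to from_pretrained carries a dataset. A usage sketch under that assumption; the model id and the bits/num_samples values are illustrative, not taken from this PR:

```python
from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

# With a dataset set, the UNet is loaded uncompressed, calibrated on prompts from
# that dataset, and quantized in hybrid mode; the other components only get
# weight compression (their config has dataset=None).
quantization_config = OVWeightQuantizationConfig(
    bits=8,                          # illustrative value
    dataset="conceptual_captions",   # must be a key of PREDEFINED_SD_DATASETS
    num_samples=200,                 # defaults to 200 in the code above when unset
)

pipe = OVStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative model id
    export=True,
    quantization_config=quantization_config,
)
```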
@@ -340,21 +336,52 @@ def _from_pretrained(
    def _prepare_unet_inputs(
        self,
-        dataset: List[str],
+        dataset: Union[str, List[Any]],
        num_samples: int,
-        num_inference_steps: int,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
+        seed: Optional[int] = 42,
        **kwargs,
    ) -> Dict[str, Any]:
        self.compile()
-        calibration_data = []
+
+        if isinstance(dataset, str):
+            dataset = deepcopy(dataset)
+            available_datasets = PREDEFINED_SD_DATASETS.keys()
+            if dataset not in available_datasets:
+                raise ValueError(
+                    f"""You have entered a string value for dataset. You can only choose between
+                    {list(available_datasets)}, but the {dataset} was found"""
+                )
+
+            from datasets import load_dataset
+
+            dataset_metadata = PREDEFINED_SD_DATASETS[dataset]
+            dataset = load_dataset(dataset, split=dataset_metadata["split"], streaming=True).shuffle(seed=seed)
+            input_names = dataset_metadata["inputs"]
+            dataset = dataset.select_columns(list(input_names.values()))
+
+            def transform_fn(data_item):
+                return {inp_name: data_item[column] for inp_name, column in input_names.items()}
+
+        else:
+
+            def transform_fn(data_item):
+                return data_item if isinstance(data_item, (list, dict)) else [data_item]

        from .quantization import InferRequestWrapper

+        calibration_data = []
        self.unet.request = InferRequestWrapper(self.unet.request, calibration_data)
-        for prompt in dataset:
-            _ = self.__call__(prompt, num_inference_steps=num_inference_steps, height=height, width=width)
+
+        for inputs in dataset:
+            inputs = transform_fn(inputs)
+            if isinstance(inputs, dict):
+                self.__call__(**inputs, height=height, width=width)
+            else:
+                self.__call__(*inputs, height=height, width=width)
+            if len(calibration_data) > num_samples:
+                break

        self.unet.request = self.unet.request.request
        return calibration_data[:num_samples]
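The calibration loop relies on InferRequestWrapper (imported from .quantization) to intercept each UNet inference call and append its inputs to calibration_data before delegating to the real request. A standalone sketch of that recording pattern; RecordingInferRequest is a hypothetical stand-in, not the actual class:

```python
import copy


class RecordingInferRequest:
    """Hypothetical stand-in for InferRequestWrapper: record inputs, then delegate."""

    def __init__(self, request, collected_inputs: list):
        self.request = request                    # the original infer request
        self.collected_inputs = collected_inputs  # shared list filled during inference

    def __call__(self, inputs, *args, **kwargs):
        # Copy the inputs so a normal pipeline run doubles as calibration data collection.
        self.collected_inputs.append(copy.deepcopy(inputs))
        return self.request(inputs, *args, **kwargs)


# Usage mirrors _prepare_unet_inputs above (pipe is assumed to be an OV diffusion pipeline):
# calibration_data = []
# pipe.unet.request = RecordingInferRequest(pipe.unet.request, calibration_data)
# pipe("a photo of an astronaut riding a horse")   # runs inference, filling calibration_data
# pipe.unet.request = pipe.unet.request.request    # unwrap when done
```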