50
50
from ..utils .constant import _TASK_ALIASES
51
51
from ..utils .import_utils import DATASETS_IMPORT_ERROR , is_datasets_available
52
52
from ..utils .modeling_utils import get_model_device
53
- from .configuration import OVConfig , OVQuantizationConfig , OVWeightQuantizationConfig , OVQuantizationMethod
53
+ from .configuration import OVConfig , OVQuantizationConfig , OVQuantizationMethod , OVWeightQuantizationConfig
54
54
from .modeling_base import OVBaseModel
55
55
from .utils import (
56
56
MAX_ONNX_OPSET ,
@@ -339,8 +339,8 @@ def _quantize_ovbasemodel(
339
339
340
340
if isinstance (self .model , OVStableDiffusionPipelineBase ):
341
341
calibration_dataset = self ._prepare_unet_dataset (
342
- quantization_config .num_samples ,
343
- dataset = calibration_dataset )
342
+ quantization_config .num_samples , dataset = calibration_dataset
343
+ )
344
344
elif Dataset is not None and isinstance (calibration_dataset , Dataset ):
345
345
calibration_dataloader = self ._get_calibration_dataloader (
346
346
calibration_dataset = calibration_dataset ,
@@ -351,14 +351,17 @@ def _quantize_ovbasemodel(
351
351
352
352
if self .model .export_feature == "text-generation" and self .model .use_cache :
353
353
calibration_dataset = self ._prepare_text_generation_dataset (
354
- quantization_config , calibration_dataloader )
354
+ quantization_config , calibration_dataloader
355
+ )
355
356
else :
356
357
calibration_dataset = nncf .Dataset (calibration_dataloader )
357
358
elif isinstance (calibration_dataset , collections .abc .Iterable ):
358
359
calibration_dataset = nncf .Dataset (calibration_dataset )
359
360
elif not isinstance (calibration_dataset , nncf .Dataset ):
360
- raise ValueError ("`calibration_dataset` must be either an `Iterable` object or an instance of "
361
- f"`nncf.Dataset` or `datasets.Dataset`. Found: { type (calibration_dataset )} ." )
361
+ raise ValueError (
362
+ "`calibration_dataset` must be either an `Iterable` object or an instance of "
363
+ f"`nncf.Dataset` or `datasets.Dataset`. Found: { type (calibration_dataset )} ."
364
+ )
362
365
363
366
if isinstance (quantization_config , OVWeightQuantizationConfig ):
364
367
if quantization_config .dataset is not None and calibration_dataset is not None :
@@ -374,8 +377,8 @@ def _quantize_ovbasemodel(
374
377
calibration_dataset = self ._prepare_gptq_dataset (quantization_config )
375
378
elif isinstance (self .model , OVStableDiffusionPipelineBase ):
376
379
calibration_dataset = self ._prepare_unet_dataset (
377
- quantization_config .num_samples ,
378
- dataset_name = quantization_config . dataset )
380
+ quantization_config .num_samples , dataset_name = quantization_config . dataset
381
+ )
379
382
else :
380
383
raise ValueError (
381
384
f"Can't create weight compression calibration dataset from string for { type (self .model )} "
@@ -385,7 +388,9 @@ def _quantize_ovbasemodel(
385
388
if calibration_dataset is None :
386
389
raise ValueError ("Calibration dataset is required to run hybrid quantization." )
387
390
if isinstance (self .model , OVStableDiffusionPipelineBase ):
388
- self .model .unet .model = _hybrid_quantization (self .model .unet .model , quantization_config , calibration_dataset )
391
+ self .model .unet .model = _hybrid_quantization (
392
+ self .model .unet .model , quantization_config , calibration_dataset
393
+ )
389
394
else :
390
395
self .model .model = _hybrid_quantization (self .model .model , quantization_config , calibration_dataset )
391
396
else :
@@ -672,18 +677,15 @@ def _prepare_gptq_dataset(self, quantization_config: OVWeightQuantizationConfig)
672
677
673
678
tokenizer = AutoTokenizer .from_pretrained (quantization_config .tokenizer )
674
679
nsamples = quantization_config .num_samples if quantization_config .num_samples else 128
675
- calibration_dataset = get_dataset (
676
- quantization_config .dataset , tokenizer , seqlen = 32 , nsamples = nsamples
677
- )
680
+ calibration_dataset = get_dataset (quantization_config .dataset , tokenizer , seqlen = 32 , nsamples = nsamples )
678
681
calibration_dataset = prepare_dataset (calibration_dataset )
679
682
calibration_dataset = nncf .Dataset (calibration_dataset , lambda x : self .model .prepare_inputs (** x ))
680
683
681
684
return calibration_dataset
682
685
683
686
def _prepare_text_generation_dataset (
684
- self ,
685
- quantization_config : OVQuantizationConfig ,
686
- calibration_dataloader : OVDataLoader ) -> nncf .Dataset :
687
+ self , quantization_config : OVQuantizationConfig , calibration_dataloader : OVDataLoader
688
+ ) -> nncf .Dataset :
687
689
# TODO: this function is not covered by tests, remove if not relevant anymore or cover by tests otherwise
688
690
689
691
# Prefetch past_key_values
@@ -705,10 +707,11 @@ def _prepare_text_generation_dataset(
705
707
return calibration_dataset
706
708
707
709
def _prepare_unet_dataset (
708
- self ,
709
- num_samples : Optional [int ] = None ,
710
- dataset_name : Optional [str ] = None ,
711
- dataset : Optional [Union [Iterable , "Dataset" ]] = None ) -> nncf .Dataset :
710
+ self ,
711
+ num_samples : Optional [int ] = None ,
712
+ dataset_name : Optional [str ] = None ,
713
+ dataset : Optional [Union [Iterable , "Dataset" ]] = None ,
714
+ ) -> nncf .Dataset :
712
715
self .model .compile ()
713
716
714
717
size = self .model .unet .config .get ("sample_size" , 64 ) * self .model .vae_scale_factor
@@ -735,16 +738,20 @@ def transform_fn(data_item):
735
738
from datasets import load_dataset
736
739
737
740
dataset_metadata = PREDEFINED_SD_DATASETS [dataset_name ]
738
- dataset = load_dataset (dataset_name , split = dataset_metadata ["split" ], streaming = True ).shuffle (seed = self .seed )
741
+ dataset = load_dataset (dataset_name , split = dataset_metadata ["split" ], streaming = True ).shuffle (
742
+ seed = self .seed
743
+ )
739
744
input_names = dataset_metadata ["inputs" ]
740
745
dataset = dataset .select_columns (list (input_names .values ()))
741
746
742
747
def transform_fn (data_item ):
743
748
return {inp_name : data_item [column ] for inp_name , column in input_names .items ()}
744
749
745
750
else :
746
- raise ValueError ("For UNet inputs collection either quantization_config.dataset or custom "
747
- "calibration_dataset must be provided." )
751
+ raise ValueError (
752
+ "For UNet inputs collection either quantization_config.dataset or custom "
753
+ "calibration_dataset must be provided."
754
+ )
748
755
749
756
calibration_data = []
750
757
try :
0 commit comments