12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import collections .abc
15
16
import copy
16
17
import inspect
17
18
import logging
49
50
from ..utils .constant import _TASK_ALIASES
50
51
from ..utils .import_utils import DATASETS_IMPORT_ERROR , is_datasets_available
51
52
from ..utils .modeling_utils import get_model_device
52
- from .configuration import OVConfig , OVQuantizationConfig , OVWeightQuantizationConfig
53
+ from .configuration import OVConfig , OVQuantizationConfig , OVWeightQuantizationConfig , OVQuantizationMethod
53
54
from .modeling_base import OVBaseModel
54
55
from .utils import (
55
56
MAX_ONNX_OPSET ,
56
57
MIN_ONNX_QDQ_OPSET ,
57
58
ONNX_WEIGHTS_NAME ,
58
59
OV_XML_FILE_NAME ,
60
+ PREDEFINED_SD_DATASETS ,
59
61
)
60
62
61
63
@@ -201,7 +203,7 @@ def from_pretrained(cls, model: PreTrainedModel, **kwargs):
201
203
202
204
def quantize (
203
205
self ,
204
- calibration_dataset : Optional [Union [datasets . Dataset , nncf .Dataset , Iterable ]] = None ,
206
+ calibration_dataset : Optional [Union [" Dataset" , nncf .Dataset , Iterable ]] = None ,
205
207
save_directory : Optional [Union [str , Path ]] = None ,
206
208
ov_config : OVConfig = None ,
207
209
file_name : Optional [str ] = None ,
@@ -325,74 +327,84 @@ def _quantize_ovbasemodel(
325
327
remove_unused_columns : bool = True ,
326
328
** kwargs ,
327
329
):
330
+ from optimum .intel .openvino .modeling_diffusion import OVStableDiffusionPipelineBase
331
+
328
332
if save_directory is not None :
329
333
save_directory = Path (save_directory )
330
334
save_directory .mkdir (parents = True , exist_ok = True )
331
-
332
335
quantization_config = ov_config .quantization_config
336
+
337
+ if calibration_dataset is not None :
338
+ # Process custom calibration dataset
339
+
340
+ if isinstance (self .model , OVStableDiffusionPipelineBase ):
341
+ calibration_dataset = self ._prepare_unet_dataset (
342
+ quantization_config .num_samples ,
343
+ dataset = calibration_dataset )
344
+ elif Dataset is not None and isinstance (calibration_dataset , Dataset ):
345
+ calibration_dataloader = self ._get_calibration_dataloader (
346
+ calibration_dataset = calibration_dataset ,
347
+ batch_size = batch_size ,
348
+ remove_unused_columns = remove_unused_columns ,
349
+ data_collator = data_collator ,
350
+ )
351
+
352
+ if self .model .export_feature == "text-generation" and self .model .use_cache :
353
+ calibration_dataset = self ._prepare_text_generation_dataset (
354
+ quantization_config , calibration_dataloader )
355
+ else :
356
+ calibration_dataset = nncf .Dataset (calibration_dataloader )
357
+ elif isinstance (calibration_dataset , collections .abc .Iterable ):
358
+ calibration_dataset = nncf .Dataset (calibration_dataset )
359
+ elif not isinstance (calibration_dataset , nncf .Dataset ):
360
+ raise ValueError ("`calibration_dataset` must be either an `Iterable` object or an instance of "
361
+ f"`nncf.Dataset` or `datasets.Dataset`. Found: { type (calibration_dataset )} ." )
362
+
333
363
if isinstance (quantization_config , OVWeightQuantizationConfig ):
364
+ if quantization_config .dataset is not None and calibration_dataset is not None :
365
+ logger .info (
366
+ "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only "
367
+ "quantization. Will rely on `calibration_dataset`."
368
+ )
369
+
334
370
if calibration_dataset is None and isinstance (quantization_config .dataset , str ):
335
371
from optimum .intel import OVModelForCausalLM
336
372
337
373
if isinstance (self .model , OVModelForCausalLM ):
338
- from optimum .gptq .data import get_dataset , prepare_dataset
339
-
340
- tokenizer = AutoTokenizer .from_pretrained (quantization_config .tokenizer )
341
- nsamples = quantization_config .num_samples if quantization_config .num_samples else 128
342
- calibration_dataset = get_dataset (
343
- quantization_config .dataset , tokenizer , seqlen = 32 , nsamples = nsamples
344
- )
345
- calibration_dataset = prepare_dataset (calibration_dataset )
346
- calibration_dataset = nncf .Dataset (calibration_dataset , lambda x : self .model .prepare_inputs (** x ))
374
+ calibration_dataset = self ._prepare_gptq_dataset (quantization_config )
375
+ elif isinstance (self .model , OVStableDiffusionPipelineBase ):
376
+ calibration_dataset = self ._prepare_unet_dataset (
377
+ quantization_config .num_samples ,
378
+ dataset_name = quantization_config .dataset )
347
379
else :
348
380
raise ValueError (
349
381
f"Can't create weight compression calibration dataset from string for { type (self .model )} "
350
382
)
351
383
352
- _weight_only_quantization (self .model .model , quantization_config , calibration_dataset )
384
+ if quantization_config .quant_method == OVQuantizationMethod .HYBRID :
385
+ if calibration_dataset is None :
386
+ raise ValueError ("Calibration dataset is required to run hybrid quantization." )
387
+ if isinstance (self .model , OVStableDiffusionPipelineBase ):
388
+ self .model .unet .model = _hybrid_quantization (self .model .unet .model , quantization_config , calibration_dataset )
389
+ else :
390
+ self .model .model = _hybrid_quantization (self .model .model , quantization_config , calibration_dataset )
391
+ else :
392
+ _weight_only_quantization (self .model .model , quantization_config , calibration_dataset )
353
393
if save_directory is not None :
354
394
self .model .save_pretrained (save_directory )
355
395
ov_config .save_pretrained (save_directory )
356
396
return
397
+
357
398
if not isinstance (quantization_config , OVQuantizationConfig ):
358
399
raise ValueError (f"Unsupported type of quantization config: { type (quantization_config )} " )
359
400
360
- if isinstance (calibration_dataset , nncf .Dataset ):
361
- quantization_dataset = calibration_dataset
362
- elif Dataset is not None and isinstance (calibration_dataset , Dataset ):
363
- calibration_dataloader = self ._get_calibration_dataloader (
364
- calibration_dataset = calibration_dataset ,
365
- batch_size = batch_size ,
366
- remove_unused_columns = remove_unused_columns ,
367
- data_collator = data_collator ,
368
- )
369
-
370
- if self .model .export_feature == "text-generation" and self .model .use_cache :
371
- # Prefetch past_key_values
372
- self .model .update_pkv_precision (True )
373
- self .model .compile ()
374
- collected_inputs = []
375
-
376
- self .model .request = InferRequestWrapper (self .model .request , collected_inputs )
377
- try :
378
- for data in calibration_dataloader :
379
- self .model .generate (** data , max_new_tokens = 1 )
380
- if len (collected_inputs ) >= quantization_config .num_samples :
381
- break
382
- finally :
383
- self .model .request = self .model .request .request
384
- quantization_dataset = nncf .Dataset (collected_inputs )
385
- else :
386
- quantization_dataset = nncf .Dataset (calibration_dataloader )
387
- else :
388
- if calibration_dataset is None :
389
- raise ValueError ("Calibration dataset is required to run quantization." )
390
- quantization_dataset = nncf .Dataset (calibration_dataset )
401
+ if calibration_dataset is None :
402
+ raise ValueError ("Calibration dataset is required to run quantization." )
391
403
392
404
# Actual model quantization
393
405
quantized_model = nncf .quantize (
394
406
self .model .model ,
395
- quantization_dataset ,
407
+ calibration_dataset ,
396
408
subset_size = quantization_config .num_samples ,
397
409
ignored_scope = quantization_config .get_ignored_scope_instance (),
398
410
model_type = nncf .ModelType (quantization_config .model_type ),
@@ -655,6 +667,103 @@ def _remove_unused_columns(self, dataset: "Dataset"):
655
667
ignored_columns = list (set (dataset .column_names ) - set (self ._signature_columns ))
656
668
return dataset .remove_columns (ignored_columns )
657
669
670
+ def _prepare_gptq_dataset (self , quantization_config : OVWeightQuantizationConfig ):
671
+ from optimum .gptq .data import get_dataset , prepare_dataset
672
+
673
+ tokenizer = AutoTokenizer .from_pretrained (quantization_config .tokenizer )
674
+ nsamples = quantization_config .num_samples if quantization_config .num_samples else 128
675
+ calibration_dataset = get_dataset (
676
+ quantization_config .dataset , tokenizer , seqlen = 32 , nsamples = nsamples
677
+ )
678
+ calibration_dataset = prepare_dataset (calibration_dataset )
679
+ calibration_dataset = nncf .Dataset (calibration_dataset , lambda x : self .model .prepare_inputs (** x ))
680
+
681
+ return calibration_dataset
682
+
683
+ def _prepare_text_generation_dataset (
684
+ self ,
685
+ quantization_config : OVQuantizationConfig ,
686
+ calibration_dataloader : OVDataLoader ) -> nncf .Dataset :
687
+ # TODO: this function is not covered by tests, remove if not relevant anymore or cover by tests otherwise
688
+
689
+ # Prefetch past_key_values
690
+ self .model .update_pkv_precision (True )
691
+ self .model .compile ()
692
+ collected_inputs = []
693
+
694
+ num_samples = quantization_config .num_samples or 200
695
+
696
+ self .model .request = InferRequestWrapper (self .model .model .request , collected_inputs )
697
+ try :
698
+ for data in calibration_dataloader :
699
+ self .model .generate (** data , max_new_tokens = 1 )
700
+ if len (collected_inputs ) >= num_samples :
701
+ break
702
+ finally :
703
+ self .model .model .request = self .model .model .request .request
704
+ calibration_dataset = nncf .Dataset (collected_inputs )
705
+ return calibration_dataset
706
+
707
+ def _prepare_unet_dataset (
708
+ self ,
709
+ num_samples : Optional [int ] = None ,
710
+ dataset_name : Optional [str ] = None ,
711
+ dataset : Optional [Union [Iterable , "Dataset" ]] = None ) -> nncf .Dataset :
712
+ self .model .compile ()
713
+
714
+ size = self .model .unet .config .get ("sample_size" , 64 ) * self .model .vae_scale_factor
715
+ height , width = 2 * (min (size , 512 ),)
716
+ num_samples = num_samples or 200
717
+
718
+ if dataset is not None :
719
+ if isinstance (dataset , nncf .Dataset ):
720
+ return dataset
721
+ if Dataset is not None and isinstance (dataset , Dataset ):
722
+ dataset = dataset .select_columns (["caption" ])
723
+
724
+ def transform_fn (data_item ):
725
+ return data_item if isinstance (data_item , (list , dict )) else [data_item ]
726
+
727
+ elif isinstance (dataset_name , str ):
728
+ available_datasets = PREDEFINED_SD_DATASETS .keys ()
729
+ if dataset_name not in available_datasets :
730
+ raise ValueError (
731
+ f"""You have entered a string value for dataset. You can only choose between
732
+ { list (available_datasets )} , but the { dataset_name } was found"""
733
+ )
734
+
735
+ from datasets import load_dataset
736
+
737
+ dataset_metadata = PREDEFINED_SD_DATASETS [dataset_name ]
738
+ dataset = load_dataset (dataset_name , split = dataset_metadata ["split" ], streaming = True ).shuffle (seed = self .seed )
739
+ input_names = dataset_metadata ["inputs" ]
740
+ dataset = dataset .select_columns (list (input_names .values ()))
741
+
742
+ def transform_fn (data_item ):
743
+ return {inp_name : data_item [column ] for inp_name , column in input_names .items ()}
744
+
745
+ else :
746
+ raise ValueError ("For UNet inputs collection either quantization_config.dataset or custom "
747
+ "calibration_dataset must be provided." )
748
+
749
+ calibration_data = []
750
+ try :
751
+ self .model .unet .request = InferRequestWrapper (self .model .unet .request , calibration_data )
752
+
753
+ for inputs in dataset :
754
+ inputs = transform_fn (inputs )
755
+ if isinstance (inputs , dict ):
756
+ self .model (** inputs , height = height , width = width )
757
+ else :
758
+ self .model (* inputs , height = height , width = width )
759
+ if len (calibration_data ) >= num_samples :
760
+ break
761
+ finally :
762
+ self .model .unet .request = self .model .unet .request .request
763
+
764
+ calibration_dataset = nncf .Dataset (calibration_data [:num_samples ])
765
+ return calibration_dataset
766
+
658
767
659
768
def _weight_only_quantization (
660
769
model : openvino .runtime .Model ,
@@ -665,11 +774,6 @@ def _weight_only_quantization(
665
774
if isinstance (config , dict ):
666
775
config = OVWeightQuantizationConfig .from_dict (quantization_config )
667
776
668
- if config .dataset is not None and calibration_dataset is not None :
669
- logger .info (
670
- "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only "
671
- "quantization. Will rely on `calibration_dataset`."
672
- )
673
777
dataset = None
674
778
if calibration_dataset is not None :
675
779
if Dataset is not None and isinstance (calibration_dataset , Dataset ):
@@ -752,7 +856,7 @@ def _collect_ops_with_weights(model):
752
856
753
857
754
858
def _hybrid_quantization (
755
- model : openvino .runtime .Model , quantization_config : OVWeightQuantizationConfig , dataset : Dict [ str , Any ]
859
+ model : openvino .runtime .Model , quantization_config : OVWeightQuantizationConfig , dataset : nncf . Dataset
756
860
) -> openvino .runtime .Model :
757
861
"""
758
862
Quantize a model in hybrid mode with NNCF which means that we quantize:
@@ -764,7 +868,7 @@ def _hybrid_quantization(
764
868
The OpenVINO Runtime model for applying hybrid quantization.
765
869
quantization_config (`OVWeightQuantizationConfig`):
766
870
The configuration containing the parameters related to quantization.
767
- dataset (`Dict[str, Any] `):
871
+ dataset (`nncf.Dataset `):
768
872
The dataset used for hybrid quantization.
769
873
Returns:
770
874
The OpenVINO Runtime model with applied hybrid quantization.
@@ -781,7 +885,7 @@ def _hybrid_quantization(
781
885
subset_size = quantization_config .num_samples if quantization_config .num_samples else 200
782
886
quantized_model = nncf .quantize (
783
887
model = compressed_model ,
784
- calibration_dataset = nncf . Dataset ( dataset ) ,
888
+ calibration_dataset = dataset ,
785
889
model_type = nncf .ModelType .TRANSFORMER ,
786
890
ignored_scope = ptq_ignored_scope ,
787
891
# SQ algo should be disabled for MatMul nodes because their weights are already compressed
0 commit comments