import copy
import inspect
import logging
+from collections import UserDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import nncf
logger = logging.getLogger(__name__)


+class CalibrationDataset(UserDict):
+    """
+    A class to store calibration datasets for quantization with NNCF. Contains an instance of `nncf.Dataset` for each
+    pipeline model component. For example, for a sequence-to-sequence pipeline with `encoder_model` and `decoder_model`
+    components, the dictionary should contain two keys: `encoder_model` and `decoder_model`.
+    """
+
+    def __init__(self, calibration_dataset: Union[nncf.Dataset, Dict[str, nncf.Dataset]]):
+        """
+        Args:
+            calibration_dataset (`Union[nncf.Dataset, Dict[str, nncf.Dataset]]`):
+                The calibration dataset to store. Can be a single `nncf.Dataset` instance or a dictionary containing
+                `nncf.Dataset` instances for each model component. In the first case it is assumed that the dataset
+                corresponds to a pipeline component named "model".
+        """
+        if isinstance(calibration_dataset, nncf.Dataset):
+            calibration_dataset = {"model": calibration_dataset}
+        super().__init__(calibration_dataset)
+
+    def __getattr__(self, item: str):
+        try:
+            return self.data[item]
+        except KeyError:
+            raise AttributeError
+
+
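
A minimal usage sketch of the new wrapper (illustrative only; `samples`, `encoder_ds` and `decoder_ds` stand for pre-built calibration inputs and `nncf.Dataset` instances that are not part of this diff):

    single = CalibrationDataset(nncf.Dataset(samples))      # single component, stored under the "model" key
    assert single["model"] is single.model                  # dict-style and attribute-style access are equivalent
    multi = CalibrationDataset({"encoder_model": encoder_ds, "decoder_model": decoder_ds})
    multi.decoder_model                                      # one `nncf.Dataset` per pipeline component
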
class OVDataLoader(PTInitializingDataLoader):
    def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
        return (), dataloader_output
@@ -157,9 +184,9 @@ class OVCalibrationDatasetBuilder:
        - a name of the dataset from `datasets`
        - a quantization config object containing dataset specification

-    Returns calibration dataset in a form of a dictionary `Dict[str, nncf.Dataset]` containing an instance of
-    `nncf.Dataset` for each model component. For example, for a sequence-to-sequence model with `encoder_model`
-    and `decoder_model` components, the dictionary will contain two keys: `encoder_model` and `decoder_model`.
+    Returns the calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each
+    model component. For example, for a sequence-to-sequence model with `encoder_model` and `decoder_model`
+    components, the result will contain two keys: `encoder_model` and `decoder_model`.
    """

    def __init__(self, model: transformers.PreTrainedModel, seed: int = 42):
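
A usage sketch of the builder itself (illustrative; `model` and `quantization_config` are assumed to exist, and the argument order follows the calls visible in the hunks below):

    builder = OVCalibrationDatasetBuilder(model)
    calibration_dataset = builder.build_from_dataset_name(quantization_config, "wikitext2")
    calibration_dataset["model"]   # an `nncf.Dataset`, keyed by pipeline component
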
@@ -185,7 +212,7 @@ def build_from_dataset(
        batch_size: Optional[int] = 1,
        data_collator: Optional[DataCollator] = None,
        remove_unused_columns: bool = False,
-    ) -> Dict[str, nncf.Dataset]:
+    ) -> CalibrationDataset:
        """

        Args:
@@ -200,9 +227,7 @@ def build_from_dataset(
            remove_unused_columns (`bool`, defaults to `False`):
                Whether to remove the columns unused by the model forward method. Not always used.
        Returns:
-            A calibration dataset in a form of a dictionary `Dict[str, nncf.Dataset]` containing an instance of
-            `nncf.Dataset` for each model component. For example, for a sequence-to-sequence model with `encoder_model`
-            and `decoder_model` components, the dictionary will contain two keys: `encoder_model` and `decoder_model`.
+            A calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each model component.
        """
        from optimum.intel import OVModelForVisualCausalLM
        from optimum.intel.openvino.modeling_decoder import OVBaseDecoderModel
@@ -213,7 +238,7 @@ def build_from_dataset(

        if isinstance(dataset, list):
            logger.warning(
-                "Providing dataset as a list is deprecated and will be removed in optimum-intel v1.24. "
+                "Providing dataset as a list is deprecated and will be removed in optimum-intel v1.25. "
                "Please provide it as `datasets.Dataset`."
            )

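
Since passing a raw list is being phased out, a caller-side sketch of the suggested replacement (assuming `samples` is a list of feature dictionaries and `builder`/`quantization_config` already exist):

    from datasets import Dataset

    calibration_samples = Dataset.from_list(samples)   # wrap the raw list into a `datasets.Dataset`
    builder.build_from_dataset(quantization_config, calibration_samples)
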
@@ -244,7 +269,7 @@ def build_from_dataset(
            return self._prepare_decoder_calibration_data(quantization_config, dataloader)
        else:
            # Assuming this is the torch model quantization scenario
-            return {"model": nncf.Dataset(dataloader)}
+            return CalibrationDataset({"model": nncf.Dataset(dataloader)})

    def build_from_dataset_name(
        self,
@@ -262,7 +287,7 @@ def build_from_dataset_name(
        batch_size: Optional[int] = 1,
        data_collator: Optional[DataCollator] = None,
        remove_unused_columns: bool = False,
-    ) -> Dict[str, nncf.Dataset]:
+    ) -> CalibrationDataset:
        """
        Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.

@@ -300,11 +325,11 @@ def build_from_dataset_name(
            remove_unused_columns (`bool`, defaults to `False`):
                Whether to remove the columns unused by the model forward method.
        Returns:
-            The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
+            A calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each model component.
        """

        if remove_unused_columns:
-            logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.24.")
+            logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.25.")

        dataset = self.load_dataset(
            dataset_name,
@@ -321,7 +346,17 @@ def build_from_dataset_name(

        return self.build_from_dataset(quantization_config, dataset, batch_size, data_collator, remove_unused_columns)

-    def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
+    def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> CalibrationDataset:
+        """
+        Builds a calibration dataset from a quantization config object. Namely, the config's `dataset` attribute is
+        used to infer the dataset name.
+
+        Args:
+            config (`OVQuantizationConfigBase`):
+                The quantization configuration object.
+        Returns:
+            A calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each model component.
+        """
        from optimum.intel import OVModelForCausalLM, OVModelForVisualCausalLM
        from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper

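
A sketch of the config-driven path (illustrative; it assumes a config class such as `OVQuantizationConfig` exposing `dataset` and `num_samples` fields, and `builder` as above):

    from optimum.intel import OVQuantizationConfig

    config = OVQuantizationConfig(dataset="wikitext2", num_samples=128)
    calibration_dataset = builder.build_from_quantization_config(config)
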
@@ -462,7 +497,7 @@ def _get_calibration_dataloader(
        Wrap dataset into a dataloader.
        """
        if remove_unused_columns:
-            logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.24.")
+            logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.25.")

        if not is_datasets_available():
            raise ValueError(DATASETS_IMPORT_ERROR.format("OVCalibrationDatasetBuilder._get_calibration_dataloader"))
@@ -486,7 +521,7 @@ def _get_calibration_dataloader(

    def _prepare_decoder_calibration_data(
        self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
-    ) -> Dict[str, nncf.Dataset]:
+    ) -> CalibrationDataset:
        """
        Prepares calibration data by collecting model inputs during inference.
        """
@@ -505,11 +540,11 @@ def _prepare_decoder_calibration_data(
        finally:
            self.model.request = self.model.request.request

-        return {"model": nncf.Dataset(collected_inputs)}
+        return CalibrationDataset(nncf.Dataset(collected_inputs))

    def _prepare_causal_lm_calibration_data(
        self, config: OVQuantizationConfigBase, seqlen: int = 32
-    ) -> Dict[str, nncf.Dataset]:
+    ) -> CalibrationDataset:
        """
        Prepares calibration data for causal language models. Relies on `optimum.gptq.data` module.
        """
@@ -530,11 +565,11 @@ def _prepare_causal_lm_calibration_data(
        calibration_dataset = prepare_dataset(calibration_dataset)
        calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x))

-        return {"model": calibration_dataset}
+        return CalibrationDataset(calibration_dataset)

    def _prepare_visual_causal_lm_calibration_data(
        self, config: OVQuantizationConfigBase, dataset: "Dataset"
-    ) -> Dict[str, nncf.Dataset]:
+    ) -> CalibrationDataset:
        """
        Prepares calibration data for VLM pipelines. Currently, collects data only for a language model component.
        """
@@ -583,11 +618,11 @@ def _prepare_visual_causal_lm_calibration_data(
            if len(calibration_data) >= num_samples:
                break

-        return {"lm_model": nncf.Dataset(calibration_data)}
+        return CalibrationDataset({"lm_model": nncf.Dataset(calibration_data)})

    def _prepare_speech_to_text_calibration_data(
        self, config: OVQuantizationConfigBase, dataset: "Dataset"
-    ) -> Dict[str, nncf.Dataset]:
+    ) -> CalibrationDataset:
        """
        Prepares calibration data for speech-to-text pipelines by inferring it on a dataset and collecting incurred inputs.
        """
@@ -624,11 +659,11 @@ def _prepare_speech_to_text_calibration_data(
        calibration_data = {}
        for model_name, model_data in collected_inputs.items():
            calibration_data[f"{model_name}_model"] = nncf.Dataset(model_data)
-        return calibration_data
+        return CalibrationDataset(calibration_data)

    def _prepare_diffusion_calibration_data(
        self, config: OVQuantizationConfigBase, dataset: "Dataset"
-    ) -> Dict[str, nncf.Dataset]:
+    ) -> CalibrationDataset:
        """
        Prepares calibration data for diffusion models by inferring it on a dataset. Currently, collects data only for
        a vision diffusion component.
@@ -658,7 +693,7 @@ def _prepare_diffusion_calibration_data(
        finally:
            diffuser.request = diffuser.request.request

-        return {diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])}
+        return CalibrationDataset({diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])})

    def _remove_unused_columns(self, dataset: "Dataset"):
        # TODO: deprecate because model.forward() may not be the method which is called during inference,