Skip to content

Commit 7a12f55

Browse files
Add CalibrationDataset class
1 parent 4b582d2 commit 7a12f55

File tree

4 files changed

+168
-90
lines changed

4 files changed

+168
-90
lines changed

optimum/intel/openvino/quantization/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@
2727

2828

2929
if is_nncf_available():
30-
# Quantization is possible only if nncf is installed
30+
# Running quantization is possible only if nncf is installed
31+
from .calibration_dataset_builder import CalibrationDataset
3132
from .quantizer import OVQuantizer

optimum/intel/openvino/quantization/calibration_dataset_builder.py

+59-24
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import copy
1515
import inspect
1616
import logging
17+
from collections import UserDict
1718
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1819

1920
import nncf
@@ -50,6 +51,32 @@
5051
logger = logging.getLogger(__name__)
5152

5253

54+
class CalibrationDataset(UserDict):
55+
"""
56+
A class to store calibration datasets for quantization with NNCF. Contains an instance of `nncf.Dataset` for each
57+
pipeline model component. For example, for a sequence-to-sequence pipeline with `encoder_model` and `decoder_model`
58+
components, the dictionary should contain two keys: `encoder_model` and `decoder_model`.
59+
"""
60+
61+
def __init__(self, calibration_dataset: Union[nncf.Dataset, Dict[str, nncf.Dataset]]):
62+
"""
63+
Args:
64+
calibration_dataset (`Union[nncf.Dataset, Dict[str, nncf.Dataset]]`):
65+
The calibration dataset to store. Can be a single `nncf.Dataset` instance or a dictionary containing
66+
`nncf.Dataset` instances for each model component. In the first case it is assumed that the dataset
67+
corresponds to a pipeline component named "model".
68+
"""
69+
if isinstance(calibration_dataset, nncf.Dataset):
70+
calibration_dataset = {"model": calibration_dataset}
71+
super().__init__(calibration_dataset)
72+
73+
def __getattr__(self, item: str):
74+
try:
75+
return self.data[item]
76+
except KeyError:
77+
raise AttributeError
78+
79+
5380
class OVDataLoader(PTInitializingDataLoader):
5481
def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
5582
return (), dataloader_output
@@ -157,9 +184,9 @@ class OVCalibrationDatasetBuilder:
157184
- a name of the dataset from `datasets`
158185
- a quantization config object containing dataset specification
159186
160-
Returns calibration dataset in a form of a dictionary `Dict[str, nncf.Dataset]` containing an instance of
161-
`nncf.Dataset` for each model component. For example, for a sequence-to-sequence model with `encoder_model`
162-
and `decoder_model` components, the dictionary will contain two keys: `encoder_model` and `decoder_model`.
187+
Returns calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each model component.
188+
For example, for a sequence-to-sequence model with `encoder_model` and `decoder_model` components, the dictionary
189+
will contain two keys: `encoder_model` and `decoder_model`.
163190
"""
164191

165192
def __init__(self, model: transformers.PreTrainedModel, seed: int = 42):
@@ -185,7 +212,7 @@ def build_from_dataset(
185212
batch_size: Optional[int] = 1,
186213
data_collator: Optional[DataCollator] = None,
187214
remove_unused_columns: bool = False,
188-
) -> Dict[str, nncf.Dataset]:
215+
) -> CalibrationDataset:
189216
"""
190217
191218
Args:
@@ -200,9 +227,7 @@ def build_from_dataset(
200227
remove_unused_columns (`bool`, defaults to `False`):
201228
Whether to remove the columns unused by the model forward method. Not always used.
202229
Returns:
203-
A calibration dataset in a form of a dictionary `Dict[str, nncf.Dataset]` containing an instance of
204-
`nncf.Dataset` for each model component. For example, for a sequence-to-sequence model with `encoder_model`
205-
and `decoder_model` components, the dictionary will contain two keys: `encoder_model` and `decoder_model`.
230+
A calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each model component.
206231
"""
207232
from optimum.intel import OVModelForVisualCausalLM
208233
from optimum.intel.openvino.modeling_decoder import OVBaseDecoderModel
@@ -213,7 +238,7 @@ def build_from_dataset(
213238

214239
if isinstance(dataset, list):
215240
logger.warning(
216-
"Providing dataset as a list is deprecated and will be removed in optimum-intel v1.24. "
241+
"Providing dataset as a list is deprecated and will be removed in optimum-intel v1.25. "
217242
"Please provide it as `datasets.Dataset`."
218243
)
219244

@@ -244,7 +269,7 @@ def build_from_dataset(
244269
return self._prepare_decoder_calibration_data(quantization_config, dataloader)
245270
else:
246271
# Assuming this is the torch model quantization scenario
247-
return {"model": nncf.Dataset(dataloader)}
272+
return CalibrationDataset({"model": nncf.Dataset(dataloader)})
248273

249274
def build_from_dataset_name(
250275
self,
@@ -262,7 +287,7 @@ def build_from_dataset_name(
262287
batch_size: Optional[int] = 1,
263288
data_collator: Optional[DataCollator] = None,
264289
remove_unused_columns: bool = False,
265-
) -> Dict[str, nncf.Dataset]:
290+
) -> CalibrationDataset:
266291
"""
267292
Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
268293
@@ -300,11 +325,11 @@ def build_from_dataset_name(
300325
remove_unused_columns (`bool`, defaults to `False`):
301326
Whether to remove the columns unused by the model forward method.
302327
Returns:
303-
The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
328+
A calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each model component.
304329
"""
305330

306331
if remove_unused_columns:
307-
logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.24.")
332+
logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.25.")
308333

309334
dataset = self.load_dataset(
310335
dataset_name,
@@ -321,7 +346,17 @@ def build_from_dataset_name(
321346

322347
return self.build_from_dataset(quantization_config, dataset, batch_size, data_collator, remove_unused_columns)
323348

324-
def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
349+
def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> CalibrationDataset:
350+
"""
351+
Builds a calibration dataset from a quantization config object. Namely, `quantization_config.dataset` property
352+
is used to infer dataset name.
353+
354+
Args:
355+
config (`OVQuantizationConfigBase`):
356+
The quantization configuration object.
357+
Returns:
358+
A calibration dataset as an instance of `CalibrationDataset` containing an `nncf.Dataset` for each model component.
359+
"""
325360
from optimum.intel import OVModelForCausalLM, OVModelForVisualCausalLM
326361
from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
327362

@@ -462,7 +497,7 @@ def _get_calibration_dataloader(
462497
Wrap dataset into a dataloader.
463498
"""
464499
if remove_unused_columns:
465-
logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.24.")
500+
logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.25.")
466501

467502
if not is_datasets_available():
468503
raise ValueError(DATASETS_IMPORT_ERROR.format("OVCalibrationDatasetBuilder._get_calibration_dataloader"))
@@ -486,7 +521,7 @@ def _get_calibration_dataloader(
486521

487522
def _prepare_decoder_calibration_data(
488523
self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
489-
) -> Dict[str, nncf.Dataset]:
524+
) -> CalibrationDataset:
490525
"""
491526
Prepares calibration data by collecting model inputs during inference.
492527
"""
@@ -505,11 +540,11 @@ def _prepare_decoder_calibration_data(
505540
finally:
506541
self.model.request = self.model.request.request
507542

508-
return {"model": nncf.Dataset(collected_inputs)}
543+
return CalibrationDataset(nncf.Dataset(collected_inputs))
509544

510545
def _prepare_causal_lm_calibration_data(
511546
self, config: OVQuantizationConfigBase, seqlen: int = 32
512-
) -> Dict[str, nncf.Dataset]:
547+
) -> CalibrationDataset:
513548
"""
514549
Prepares calibration data for causal language models. Relies on `optimum.gptq.data` module.
515550
"""
@@ -530,11 +565,11 @@ def _prepare_causal_lm_calibration_data(
530565
calibration_dataset = prepare_dataset(calibration_dataset)
531566
calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x))
532567

533-
return {"model": calibration_dataset}
568+
return CalibrationDataset(calibration_dataset)
534569

535570
def _prepare_visual_causal_lm_calibration_data(
536571
self, config: OVQuantizationConfigBase, dataset: "Dataset"
537-
) -> Dict[str, nncf.Dataset]:
572+
) -> CalibrationDataset:
538573
"""
539574
Prepares calibration data for VLM pipelines. Currently, collects data only for a language model component.
540575
"""
@@ -583,11 +618,11 @@ def _prepare_visual_causal_lm_calibration_data(
583618
if len(calibration_data) >= num_samples:
584619
break
585620

586-
return {"lm_model": nncf.Dataset(calibration_data)}
621+
return CalibrationDataset({"lm_model": nncf.Dataset(calibration_data)})
587622

588623
def _prepare_speech_to_text_calibration_data(
589624
self, config: OVQuantizationConfigBase, dataset: "Dataset"
590-
) -> Dict[str, nncf.Dataset]:
625+
) -> CalibrationDataset:
591626
"""
592627
Prepares calibration data for speech-to-text pipelines by inferring it on a dataset and collecting incurred inputs.
593628
"""
@@ -624,11 +659,11 @@ def _prepare_speech_to_text_calibration_data(
624659
calibration_data = {}
625660
for model_name, model_data in collected_inputs.items():
626661
calibration_data[f"{model_name}_model"] = nncf.Dataset(model_data)
627-
return calibration_data
662+
return CalibrationDataset(calibration_data)
628663

629664
def _prepare_diffusion_calibration_data(
630665
self, config: OVQuantizationConfigBase, dataset: "Dataset"
631-
) -> Dict[str, nncf.Dataset]:
666+
) -> CalibrationDataset:
632667
"""
633668
Prepares calibration data for diffusion models by inferring it on a dataset. Currently, collects data only for
634669
a vision diffusion component.
@@ -658,7 +693,7 @@ def _prepare_diffusion_calibration_data(
658693
finally:
659694
diffuser.request = diffuser.request.request
660695

661-
return {diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])}
696+
return CalibrationDataset({diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])})
662697

663698
def _remove_unused_columns(self, dataset: "Dataset"):
664699
# TODO: deprecate because model.forward() may not be the method which is called during inference,

0 commit comments

Comments
 (0)