
Commit 4b582d2

Polishing
1 parent 0b0d54c commit 4b582d2

4 files changed: +150 −43 lines


optimum/commands/export/openvino.py (+2 −5)
@@ -313,11 +313,8 @@ def parse_args(parser: "ArgumentParser"):
     def run(self):
         from ...exporters.openvino.__main__ import infer_task, main_export, maybe_convert_tokenizers
         from ...exporters.openvino.utils import save_preprocessors
-        from ...intel.openvino.quantization.configuration import (
-            _DEFAULT_4BIT_CONFIG,
-            OVConfig,
-            get_default_int4_config,
-        )
+        from ...intel.openvino.quantization import OVConfig
+        from ...intel.openvino.quantization.configuration import _DEFAULT_4BIT_CONFIG, get_default_int4_config
 
         if self.args.library is None:
             # TODO: add revision, subfolder and token to args
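
Note: expressed as absolute imports (assuming the package layout implied by the relative paths above), the change splits the previous grouped import so that `OVConfig` now comes from the `optimum.intel.openvino.quantization` package root:

# Before (equivalent absolute form)
from optimum.intel.openvino.quantization.configuration import _DEFAULT_4BIT_CONFIG, OVConfig, get_default_int4_config

# After (equivalent absolute form)
from optimum.intel.openvino.quantization import OVConfig
from optimum.intel.openvino.quantization.configuration import _DEFAULT_4BIT_CONFIG, get_default_int4_config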

optimum/intel/openvino/configuration.py (+1 −1)
@@ -33,6 +33,6 @@
 logger = logging.getLogger(__name__)
 
 logger.warning(
-    "`optimum.intel.configuration` import path is deprecated and will be removed in optimum-intel v1.25."
+    "`optimum.intel.configuration` import path is deprecated and will be removed in optimum-intel v1.24. "
     "Please use `optimum.intel.quantization.configuration` instead."
 )
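
Note: the deprecation notice in this module is emitted through Python's `logging` machinery rather than `warnings`, so downstream code can surface or silence it with standard logging configuration. A minimal sketch, assuming the logger name follows from `logging.getLogger(__name__)` in the file shown above:

import logging

# Show the notice while migrating imports to `optimum.intel.quantization.configuration` ...
logging.getLogger("optimum.intel.openvino.configuration").setLevel(logging.WARNING)

# ... or hide it once the migration is done.
logging.getLogger("optimum.intel.openvino.configuration").setLevel(logging.ERROR)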

optimum/intel/openvino/quantization/calibration_dataset_builder.py (+105 −24)
@@ -14,7 +14,7 @@
 import copy
 import inspect
 import logging
-from typing import Any, Callable, Dict, List, Optional, Sized, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import nncf
 import openvino
@@ -149,30 +149,74 @@ def __getattr__(self, attr):
 
 
 class OVCalibrationDatasetBuilder:
-    def __init__(self, model: transformers.PreTrainedModel, seed: int = 42, **kwargs):
+    """
+    A class to build calibration datasets for quantization with NNCF.
+
+    Allows to build a calibration dataset from:
+    - a `datasets.Dataset` object
+    - a name of the dataset from `datasets`
+    - a quantization config object containing dataset specification
+
+    Returns calibration dataset in a form of a dictionary `Dict[str, nncf.Dataset]` containing an instance of
+    `nncf.Dataset` for each model component. For example, for a sequence-to-sequence model with `encoder_model`
+    and `decoder_model` components, the dictionary will contain two keys: `encoder_model` and `decoder_model`.
+    """
+
+    def __init__(self, model: transformers.PreTrainedModel, seed: int = 42):
+        """
+
+        Args:
+            model (`transformers.PreTrainedModel`):
+                The model to build calibration dataset for.
+            seed (`int`, defaults to 42):
+                Random seed to use for reproducibility.
+        """
         self.model = model
         self.seed = seed
-        # TODO: deprecate because model.forward() may not be the method which is called during inference, for example there is model.generate()
+        # TODO: deprecate "signature_columns": model.forward() may not be the method which is called during inference,
+        # for example there is model.generate()
         signature = inspect.signature(self.model.forward)
         self._signature_columns = list(signature.parameters.keys())
 
     def build_from_dataset(
         self,
         quantization_config: OVQuantizationConfigBase,
-        dataset: Union["Dataset", Sized],
+        dataset: Union["Dataset", List],
         batch_size: Optional[int] = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = False,
     ) -> Dict[str, nncf.Dataset]:
-        # TODO: deprecate remove_unused_columns ?
+        """
 
+        Args:
+            quantization_config (`OVQuantizationConfigBase`):
+                The quantization configuration object.
+            dataset (`Union[datasets.Dataset, List]`):
+                The dataset to collect calibration data from.
+            batch_size (`int`, defaults to 1):
+                The number of calibration samples to load per batch. Not always used.
+            data_collator (`DataCollator`, *optional*):
+                The function to use to form a batch from a list of elements of the calibration dataset. Not always used.
+            remove_unused_columns (`bool`, defaults to `False`):
+                Whether to remove the columns unused by the model forward method. Not always used.
+        Returns:
+            A calibration dataset in a form of a dictionary `Dict[str, nncf.Dataset]` containing an instance of
+            `nncf.Dataset` for each model component. For example, for a sequence-to-sequence model with `encoder_model`
+            and `decoder_model` components, the dictionary will contain two keys: `encoder_model` and `decoder_model`.
+        """
         from optimum.intel import OVModelForVisualCausalLM
         from optimum.intel.openvino.modeling_decoder import OVBaseDecoderModel
         from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
 
         if is_diffusers_available():
             from optimum.intel.openvino.modeling_diffusion import OVDiffusionPipeline
 
+        if isinstance(dataset, list):
+            logger.warning(
+                "Providing dataset as a list is deprecated and will be removed in optimum-intel v1.24. "
+                "Please provide it as `datasets.Dataset`."
+            )
+
         if isinstance(self.model, (OVModelForVisualCausalLM, _OVModelForWhisper)) or (
             is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline)
         ):
@@ -189,15 +233,17 @@ def build_from_dataset(
             elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline):
                 return self._prepare_diffusion_calibration_data(quantization_config, dataset)
             else:
-                # TODO
-                raise Exception()
+                raise RuntimeError("Unsupported model type for calibration dataset collection.")
         else:
             # Prepare from dataloader
-            dataloader = self._get_calibration_dataloader(dataset, batch_size, data_collator, remove_unused_columns)
+            # Setting `remove_unused_columns=True` until it is not deprecated
+            dataloader = self._get_calibration_dataloader(
+                dataset, batch_size, data_collator, remove_unused_columns=True
+            )
             if isinstance(self.model, OVBaseDecoderModel):
                 return self._prepare_decoder_calibration_data(quantization_config, dataloader)
             else:
-                # Torch model quantization scenario
+                # Assuming this is the torch model quantization scenario
                 return {"model": nncf.Dataset(dataloader)}
 
     def build_from_dataset_name(
@@ -221,6 +267,8 @@ def build_from_dataset_name(
         Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
 
         Args:
+            quantization_config (`OVQuantizationConfigBase`):
+                The quantization configuration object.
             dataset_name (`str`):
                 The dataset repository name on the Hugging Face Hub or path to a local directory containing data files
                 in generic formats and optionally a dataset script, if it requires some code to read the data files.
@@ -243,10 +291,20 @@
                 Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                 should only be set to `True` for repositories you trust and in which you have read the code, as it will
                 execute code present on the Hub on your local machine.
+            streaming (`bool`, defaults to `False`):
+                Whether to load dataset in streaming mode.
+            batch_size (`int`, defaults to 1):
+                The number of calibration samples to load per batch.
+            data_collator (`DataCollator`, *optional*):
+                The function to use to form a batch from a list of elements of the calibration dataset.
+            remove_unused_columns (`bool`, defaults to `False`):
+                Whether to remove the columns unused by the model forward method.
         Returns:
             The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
         """
-        # TODO: deprecate remove_unused_columns ?
+
+        if remove_unused_columns:
+            logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.24.")
 
         dataset = self.load_dataset(
             dataset_name,
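
Note: a minimal usage sketch of the builder API documented in the hunks above. The import location of `OVCalibrationDatasetBuilder`, the checkpoint names, and the `OVQuantizationConfig` arguments are illustrative assumptions, not taken from this diff.

from datasets import load_dataset
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM, OVQuantizationConfig
from optimum.intel.openvino.quantization.calibration_dataset_builder import OVCalibrationDatasetBuilder

model = OVModelForCausalLM.from_pretrained("gpt2", export=True)  # hypothetical checkpoint
config = OVQuantizationConfig(num_samples=32)
builder = OVCalibrationDatasetBuilder(model, seed=42)

# Option 1: pass an already prepared `datasets.Dataset` of model inputs
# (passing a plain list is deprecated per the warning added above).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:32]")
dataset = dataset.map(
    lambda sample: tokenizer(sample["text"], truncation=True, max_length=128),
    remove_columns=dataset.column_names,
)
calibration = builder.build_from_dataset(config, dataset, batch_size=1)
# `calibration` is a Dict[str, nncf.Dataset], e.g. {"model": <nncf.Dataset>} for a single-component model.

# Option 2: let the builder load the dataset by name from the Hugging Face Hub,
# using only parameters that are visible in this diff.
calibration = builder.build_from_dataset_name(
    config,
    "wikitext",  # dataset repository name (example value)
    streaming=False,
    batch_size=1,
    trust_remote_code=False,
)
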
@@ -318,8 +376,9 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> Di
             elif isinstance(config.dataset, list) and all(isinstance(it, str) for it in config.dataset):
                 dataset = config.dataset
             else:
-                # TODO
-                raise Exception()
+                raise RuntimeError(
+                    "Please provide dataset as one of the accepted dataset labels or as a list of string prompts."
+                )
 
             return self.build_from_dataset(config, dataset)
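
Note: the clearer error message added here spells out the two accepted forms of `config.dataset` on this path. A sketch of both forms, using `OVWeightQuantizationConfig` from the existing optimum-intel API as an example; the exact set of accepted dataset labels depends on the model type and is not defined in this diff.

from optimum.intel import OVWeightQuantizationConfig

# 1) one of the accepted dataset labels
config_by_label = OVWeightQuantizationConfig(bits=4, dataset="wikitext2")

# 2) a list of string prompts
config_by_prompts = OVWeightQuantizationConfig(
    bits=4, dataset=["First calibration prompt", "Second calibration prompt"]
)

# Either config can then be fed to the builder's config-driven entry point:
# calibration = builder.build_from_quantization_config(config_by_label)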

@@ -362,13 +421,13 @@ def load_dataset(
                 Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                 should only be set to `True` for repositories you trust and in which you have read the code, as it will
                 execute code present on the Hub on your local machine.
+            streaming (`bool`, defaults to `False`):
+                Whether to load dataset in streaming mode.
         Returns:
             The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
         """
-        # TODO: deprecate remove_unused_columns ?
         if not is_datasets_available():
-            # TODO: update name
-            raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer.get_calibration_dataset"))
+            raise ValueError(DATASETS_IMPORT_ERROR.format("OVCalibrationDatasetBuilder.load_dataset"))
 
         from datasets import load_dataset
 
@@ -394,14 +453,19 @@
 
     def _get_calibration_dataloader(
         self,
-        dataset: Union["Dataset", Sized],
+        dataset: Union["Dataset", List],
         batch_size: Optional[int] = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = False,
     ) -> OVDataLoader:
+        """
+        Wrap dataset into a dataloader.
+        """
+        if remove_unused_columns:
+            logger.warning("`remove_unused_columns` is deprecated and will be removed in optimum-intel v1.24.")
+
         if not is_datasets_available():
-            # TODO: update name
-            raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer.get_calibration_dataset"))
+            raise ValueError(DATASETS_IMPORT_ERROR.format("OVCalibrationDatasetBuilder._get_calibration_dataloader"))
 
         from datasets import Dataset, IterableDataset
 
@@ -420,14 +484,12 @@ def _get_calibration_dataloader(
         )
         return OVDataLoader(dataloader)
 
-    def _remove_unused_columns(self, dataset: "Dataset"):
-        # TODO: deprecate because model.forward() may not be the method which is called during inference, for example there is model.generate()
-        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
-        return dataset.remove_columns(ignored_columns)
-
     def _prepare_decoder_calibration_data(
         self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
     ) -> Dict[str, nncf.Dataset]:
+        """
+        Prepares calibration data by collecting model inputs during inference.
+        """
         # Prefetch past_key_values
         self.model.update_pkv_precision(True)
         self.model.compile()
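
Note: "collecting model inputs during inference" in the docstring above means recording whatever the compiled model is called with while it runs over the calibration dataloader, then wrapping the recorded samples in `nncf.Dataset`. A generic illustration of that pattern, not the library's actual hook:

import nncf

def make_input_collector(infer_fn, storage):
    # Wrap a callable so that every set of inputs it receives is kept for calibration.
    def wrapped(inputs, *args, **kwargs):
        storage.append(inputs)
        return infer_fn(inputs, *args, **kwargs)
    return wrapped

collected_inputs = []
# Hypothetical wiring: route the model's inference call through the collector,
# run a handful of calibration samples (e.g. via generate()), then build the dataset.
# model.request = make_input_collector(model.request, collected_inputs)
calibration_dataset = nncf.Dataset(collected_inputs)
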
@@ -446,8 +508,11 @@ def _prepare_decoder_calibration_data(
         return {"model": nncf.Dataset(collected_inputs)}
 
     def _prepare_causal_lm_calibration_data(
-        self, config: OVQuantizationConfigBase, seqlen: Optional[int] = 32
+        self, config: OVQuantizationConfigBase, seqlen: int = 32
     ) -> Dict[str, nncf.Dataset]:
+        """
+        Prepares calibration data for causal language models. Relies on `optimum.gptq.data` module.
+        """
         from optimum.gptq.data import get_dataset, prepare_dataset
 
         tokenizer = AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=config.trust_remote_code)
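
Note: per the new docstring, this path delegates sample generation to `optimum.gptq.data`. A hedged sketch of how those helpers are typically used (the checkpoint name and sample counts are illustrative, not from this diff):

from optimum.gptq.data import get_dataset, prepare_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Draw tokenized calibration samples and batch them for inference.
samples = get_dataset("wikitext2", tokenizer, nsamples=32, seqlen=32)
samples = prepare_dataset(samples)
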
@@ -470,6 +535,9 @@ def _prepare_causal_lm_calibration_data(
     def _prepare_visual_causal_lm_calibration_data(
         self, config: OVQuantizationConfigBase, dataset: "Dataset"
     ) -> Dict[str, nncf.Dataset]:
+        """
+        Prepares calibration data for VLM pipelines. Currently, collects data only for a language model component.
+        """
         processor = AutoProcessor.from_pretrained(config.processor, trust_remote_code=config.trust_remote_code)
         try:
             tokenizer = AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=config.trust_remote_code)
@@ -520,6 +588,9 @@ def _prepare_visual_causal_lm_calibration_data(
     def _prepare_speech_to_text_calibration_data(
         self, config: OVQuantizationConfigBase, dataset: "Dataset"
     ) -> Dict[str, nncf.Dataset]:
+        """
+        Prepares calibration data for speech-to-text pipelines by inferring it on a dataset and collecting incurred inputs.
+        """
         from optimum.intel.openvino.modeling_seq2seq import OVDecoder, OVEncoder
 
         models: Dict[str, Union[OVEncoder, OVDecoder]] = {}
@@ -558,6 +629,10 @@ def _prepare_speech_to_text_calibration_data(
     def _prepare_diffusion_calibration_data(
         self, config: OVQuantizationConfigBase, dataset: "Dataset"
     ) -> Dict[str, nncf.Dataset]:
+        """
+        Prepares calibration data for diffusion models by inferring it on a dataset. Currently, collects data only for
+        a vision diffusion component.
+        """
         self.model.compile()
 
         diffuser_model_name = "unet" if self.model.unet is not None else "transformer"
@@ -584,3 +659,9 @@ def _prepare_diffusion_calibration_data(
             diffuser.request = diffuser.request.request
 
         return {diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])}
+
+    def _remove_unused_columns(self, dataset: "Dataset"):
+        # TODO: deprecate because model.forward() may not be the method which is called during inference,
+        # for example there is model.generate()
+        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
+        return dataset.remove_columns(ignored_columns)
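
Note: the relocated `_remove_unused_columns` helper simply drops every dataset column that `model.forward()` does not accept, based on the signature captured in `__init__`. A standalone illustration of that set-difference logic, with hypothetical column names:

import inspect

def unused_columns(column_names, forward_fn):
    # Columns present in the dataset but absent from the callable's signature.
    signature_columns = set(inspect.signature(forward_fn).parameters.keys())
    return [name for name in column_names if name not in signature_columns]

def forward(input_ids=None, attention_mask=None):  # stand-in for model.forward
    pass

print(unused_columns(["input_ids", "attention_mask", "text", "label"], forward))
# -> ['text', 'label']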
