
Commit a77da49

WIP 4
1 parent bfa5014 commit a77da49

File tree

3 files changed: +207, −631 lines changed

optimum/intel/openvino/quantization/dataset_builder.py

+57 −37
@@ -14,32 +14,47 @@
 import copy
 import inspect
 import logging
-import warnings
-from typing import Union, List, Any, Tuple, Dict, Optional, Iterable, Callable, Sized
+from typing import Any, Callable, Dict, List, Optional, Sized, Tuple, Union
 
 import nncf
 import openvino
 import requests
 import torch
 import transformers
-from PIL.Image import Image
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from nncf.torch.initialization import PTInitializingDataLoader
-from torch.utils.data import RandomSampler, DataLoader
+from PIL.Image import Image
+from torch.utils.data import DataLoader, RandomSampler
 from tqdm import tqdm
-from transformers import DataCollator, default_data_collator, AutoTokenizer, AutoProcessor
+from transformers import AutoProcessor, AutoTokenizer, DataCollator, default_data_collator
 
-from optimum.intel import is_accelerate_available, OVModelForCausalLM, OVModelForVisualCausalLM, \
-    OVModelForSpeechSeq2Seq, OVDiffusionPipeline
+from optimum.intel import (
+    OVModelForCausalLM,
+    OVModelForVisualCausalLM,
+    is_accelerate_available,
+)
 from optimum.intel.openvino.modeling_decoder import OVBaseDecoderModel
+from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
 from optimum.intel.openvino.quantization import OVQuantizationConfigBase
-from optimum.intel.openvino.utils import PREDEFINED_VISUAL_LM_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS, \
-    PREDEFINED_DIFFUSION_DATASETS
-from optimum.intel.utils.import_utils import is_datasets_available, DATASETS_IMPORT_ERROR, is_datasets_version
+from optimum.intel.openvino.utils import (
+    PREDEFINED_DIFFUSION_DATASETS,
+    PREDEFINED_SPEECH_TO_TEXT_DATASETS,
+    PREDEFINED_VISUAL_LM_DATASETS,
+)
+from optimum.intel.utils.import_utils import (
+    DATASETS_IMPORT_ERROR,
+    is_datasets_available,
+    is_datasets_version,
+    is_diffusers_available,
+)
+
 
 if is_datasets_available():
     from datasets import Dataset
 
+if is_diffusers_available():
+    from optimum.intel.openvino.modeling_diffusion import OVDiffusionPipeline
+
 logger = logging.getLogger(__name__)
 
 
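Note: the reorganized imports above also move `OVDiffusionPipeline` behind an `is_diffusers_available()` guard, mirroring the existing `is_datasets_available()` guard. A minimal sketch of that availability-guard pattern follows; it is illustrative only, with `importlib.util.find_spec` standing in for the real helper from `optimum.intel.utils.import_utils`.

# Sketch of the optional-backend guard used above; not the module's exact code.
import importlib.util


def diffusers_installed() -> bool:
    # Hypothetical stand-in for is_diffusers_available(): True if `diffusers` can be imported.
    return importlib.util.find_spec("diffusers") is not None


if diffusers_installed():
    # Diffusion classes are only imported (and referenced) when the backend is present.
    from optimum.intel.openvino.modeling_diffusion import OVDiffusionPipeline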
@@ -158,14 +173,14 @@ def build_from_dataset(
         remove_unused_columns: bool = False,
     ) -> Dict[str, nncf.Dataset]:
         # TODO: deprecate remove_unused_columns ?
-
+
         dataloader = self._get_calibration_dataloader(dataset, batch_size, data_collator, remove_unused_columns)
 
         if isinstance(self.model, OVBaseDecoderModel):
             return self._prepare_decoder_calibration_data(quantization_config, dataloader)
         elif isinstance(self.model, OVModelForVisualCausalLM):
             return self._prepare_visual_causal_lm_calibration_data(quantization_config, dataloader)
-        elif isinstance(self.model, OVModelForSpeechSeq2Seq):
+        elif isinstance(self.model, _OVModelForWhisper):
             return self._prepare_speech_to_text_calibration_data(quantization_config, dataloader)
         elif isinstance(self.model, OVDiffusionPipeline):
             return self._prepare_diffusion_calibration_data(quantization_config, dataloader)
@@ -216,8 +231,8 @@ def build_from_dataset_name(
                 The calibration `datasets.Dataset` to use for the post-training static quantization calibration step.
         """
         # TODO: deprecate remove_unused_columns ?
-
-        dataset = self._load_dataset(
+
+        dataset = self.load_dataset(
             dataset_name,
             dataset_config_name,
             dataset_split,
@@ -228,13 +243,13 @@ def build_from_dataset_name(
             trust_remote_code,
             streaming,
         )
-
+
         return self.build_from_dataset(quantization_config, dataset, batch_size, data_collator, remove_unused_columns)
 
     def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
         if isinstance(self, OVModelForCausalLM):
             return self._prepare_causal_lm_calibration_data(self, config)
-        elif isinstance(self, (OVModelForVisualCausalLM, OVModelForSpeechSeq2Seq)):
+        elif isinstance(self, (OVModelForVisualCausalLM, _OVModelForWhisper)):
             if config.processor is None:
                 raise ValueError(
                     "`processor` must be specified in order to run data-aware quantization. Please provide it as a"
@@ -261,8 +276,11 @@ def preprocess_function(item):
 
                 try:
                     inputs = self.model.preprocess_inputs(
-                        text=instruction, image=image, processor=processor, tokenizer=tokenizer,
-                        config=self.model.config
+                        text=instruction,
+                        image=image,
+                        processor=processor,
+                        tokenizer=tokenizer,
+                        config=self.model.config,
                     )
                 except ValueError as value_error:
                     if "Tokenizer is required." in str(value_error) and tokenizer_error is not None:
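For context, the `preprocess_inputs` call above consumes the processor and tokenizer that `build_from_quantization_config` requires on the quantization config (see the `processor` check in the previous hunk). A hedged example of supplying these fields follows; it assumes a recent optimum-intel where `OVWeightQuantizationConfig` exposes `dataset`, `processor`, and `num_samples`, and the model/dataset identifiers are illustrative, not taken from this diff.

# Hedged example: the config fields referenced above (dataset, processor, num_samples)
# populated for data-aware weight quantization of a vision-language model.
from optimum.intel import OVWeightQuantizationConfig

quantization_config = OVWeightQuantizationConfig(
    bits=4,
    dataset="contextual",                  # illustrative predefined visual-LM dataset name
    processor="llava-hf/llava-1.5-7b-hf",  # processor id; required by the check shown above
    num_samples=32,
)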
@@ -278,7 +296,7 @@ def preprocess_function(item):
                 preprocess_function=preprocess_function,
                 trust_remote_code=trc,
             )
-        elif isinstance(self.model, OVModelForSpeechSeq2Seq):
+        elif isinstance(self.model, _OVModelForWhisper):
             dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset]
 
             def preprocess_function(item):
@@ -320,12 +338,12 @@ def preprocess_function(item):
                 preprocess_function=preprocess_function,
                 streaming=dataset_metadata["streaming"],
             )
-        elif not(isinstance(dataset, list) and all(isinstance(it, str) for it in dataset)):
+        elif not (isinstance(dataset, list) and all(isinstance(it, str) for it in dataset)):
             raise Exception
 
         return self.build_from_dataset(config, dataset)
 
-    def _load_dataset(
+    def load_dataset(
         self,
         dataset_name: str,
         dataset_config_name: Optional[str] = None,
@@ -371,7 +389,13 @@ def _load_dataset(
 
         from datasets import load_dataset
 
-        datasets_kwargs = {"name": dataset_config_name, "split": dataset_split, "token": token, "cache_dir": cache_dir, "streaming": streaming}
+        datasets_kwargs = {
+            "name": dataset_config_name,
+            "split": dataset_split,
+            "token": token,
+            "cache_dir": cache_dir,
+            "streaming": streaming,
+        }
         if is_datasets_version(">=", "2.20.0"):
             datasets_kwargs["trust_remote_code"] = trust_remote_code
 
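The expanded `datasets_kwargs` above are forwarded to `datasets.load_dataset`, with `trust_remote_code` only added on `datasets` >= 2.20.0 because older releases reject the argument. A small standalone sketch of that pattern; the dataset name and values are illustrative.

# Illustrative reproduction of the version-gated kwargs built above.
from datasets import load_dataset

from optimum.intel.utils.import_utils import is_datasets_version

datasets_kwargs = {
    "name": "wikitext-2-raw-v1",
    "split": "train",
    "token": None,
    "cache_dir": None,
    "streaming": False,
}
if is_datasets_version(">=", "2.20.0"):
    # Older `datasets` releases do not accept this keyword, hence the version gate.
    datasets_kwargs["trust_remote_code"] = False

calibration_dataset = load_dataset("wikitext", **datasets_kwargs)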
@@ -393,7 +417,7 @@ def _get_calibration_dataloader(
         if not is_datasets_available():
             # TODO: update name
             raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer.get_calibration_dataset"))
-
+
         from datasets import Dataset
 
         data_collator = data_collator if data_collator is not None else default_data_collator
@@ -433,7 +457,9 @@ def _prepare_decoder_calibration_data(
 
         return {"model": nncf.Dataset(collected_inputs)}
 
-    def _prepare_causal_lm_calibration_data(self, config: OVQuantizationConfigBase, seqlen: Optional[int] = 32) -> Dict[str, nncf.Dataset]:
+    def _prepare_causal_lm_calibration_data(
+        self, config: OVQuantizationConfigBase, seqlen: Optional[int] = 32
+    ) -> Dict[str, nncf.Dataset]:
         from optimum.gptq.data import get_dataset, prepare_dataset
 
         tokenizer = AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=config.trust_remote_code)
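`_prepare_causal_lm_calibration_data` reuses the GPTQ data helpers imported above to sample and collate short calibration sequences. A hedged sketch of those two calls in isolation; the tokenizer and argument values are illustrative, not taken from this diff.

# Hedged sketch: sample short calibration sequences with the GPTQ helpers used above.
from optimum.gptq.data import get_dataset, prepare_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
samples = get_dataset("wikitext2", tokenizer, seqlen=32, nsamples=32)
samples = prepare_dataset(samples)  # pads/collates the samples into model-ready batches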
@@ -453,7 +479,9 @@ def _prepare_causal_lm_calibration_data(self, config: OVQuantizationConfigBase,
 
         return {"model": calibration_dataset}
 
-    def _prepare_visual_causal_lm_calibration_data(self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
+    def _prepare_visual_causal_lm_calibration_data(
+        self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
+    ) -> Dict[str, nncf.Dataset]:
         calibration_data = []
         num_samples = quantization_config.num_samples or 32
         for inputs in tqdm(dataloader, desc="Collecting calibration dataset", total=num_samples):
@@ -477,9 +505,11 @@ def _prepare_visual_causal_lm_calibration_data(self, quantization_config: OVQuan
             if len(calibration_data) >= num_samples:
                 break
 
-        return {"language_model": nncf.Dataset(calibration_data)}
+        return {"lm_model": nncf.Dataset(calibration_data)}
 
-    def _prepare_speech_to_text_calibration_data(self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
+    def _prepare_speech_to_text_calibration_data(
+        self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
+    ) -> Dict[str, nncf.Dataset]:
         encoder_calibration_data = []
         encoder_model = self.model.encoder
         encoder_model._compile()
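The visual-LM path now registers its calibration set under the "lm_model" key instead of "language_model". An illustrative shape of the returned mapping follows; the sample inputs are placeholders, and `nncf.Dataset` simply wraps any iterable of model inputs.

# Illustrative shape of the result after the key rename; the data values are made up.
import nncf

calibration_data = [{"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}]
visual_lm_calibration = {"lm_model": nncf.Dataset(calibration_data)}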
@@ -535,16 +565,6 @@ def _prepare_diffusion_calibration_data(
         size = diffuser.config.get("sample_size", 64) * self.model.vae_scale_factor
         height, width = 2 * (min(size, 512),)
 
-        # TODO: move the logic below to ov_quantizer
-        # if dataset is not None:
-        #     if isinstance(dataset, nncf.Dataset):
-        #         return dataset
-        #     if is_datasets_available() and isinstance(dataset, Dataset):
-        #         dataset = dataset.select_columns(["caption"])
-        #
-        #     def transform_fn(data_item):
-        #         return data_item if isinstance(data_item, (list, dict)) else [data_item]
-
         num_samples = quantization_config.num_samples or 200
         calibration_data = []
         try:
