Commit 1b960db

WIP 2

1 parent f8da688 commit 1b960db

File tree

3 files changed: +132 −140 lines changed

optimum/intel/openvino/quantization/dataset_builder.py

+124 −133
@@ -29,10 +29,12 @@
 from tqdm import tqdm
 from transformers import DataCollator, default_data_collator, AutoTokenizer, AutoProcessor
 
-from optimum.intel import is_accelerate_available, OVBaseDecoderModel, OVModelForCausalLM, OVModelForVisualCausalLM, \
+from optimum.intel import is_accelerate_available, OVModelForCausalLM, OVModelForVisualCausalLM, \
     OVModelForSpeechSeq2Seq, OVDiffusionPipeline
+from optimum.intel.openvino.modeling_decoder import OVBaseDecoderModel
 from optimum.intel.openvino.quantization import OVQuantizationConfigBase
-from optimum.intel.openvino.utils import PREDEFINED_VISUAL_LM_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS
+from optimum.intel.openvino.utils import PREDEFINED_VISUAL_LM_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS, \
+    PREDEFINED_DIFFUSION_DATASETS
 from optimum.intel.utils.import_utils import is_datasets_available, DATASETS_IMPORT_ERROR, is_datasets_version
 
 if is_datasets_available():
@@ -149,8 +151,8 @@ def __init__(self, model: transformers.PreTrainedModel, seed: int = 42, **kwargs
 
     def build_from_dataset(
         self,
+        quantization_config: OVQuantizationConfigBase,
         dataset: Union["Dataset", Sized],
-        num_samples: Optional[int],
         batch_size: Optional[int] = 1,
         data_collator: Optional[DataCollator] = None,
         remove_unused_columns: bool = False,
@@ -160,16 +162,19 @@ def build_from_dataset(
         dataloader = self._get_calibration_dataloader(dataset, batch_size, data_collator, remove_unused_columns)
 
         if isinstance(self.model, OVBaseDecoderModel):
-            calibration_datasets = self._prepare_decoder_calibration_data(dataloader, num_samples)
+            return self._prepare_decoder_calibration_data(dataloader, quantization_config.num_samples)
+        elif isinstance(self.model, OVModelForVisualCausalLM):
+            return self._prepare_visual_causal_lm_calibration_data(dataloader)
+        elif isinstance(self.model, OVModelForSpeechSeq2Seq):
+            return self._prepare_speech_to_text_calibration_data(dataloader, quantization_config.num_samples)
         elif isinstance(self.model, OVDiffusionPipeline):
-            calibration_datasets = self._prepare_diffusion_calibration_data(dataloader=dataloader, num_samples=num_samples)
+            return self._prepare_diffusion_calibration_data(dataloader=dataloader, num_samples=quantization_config.num_samples)
         else:
             raise Exception
 
-        return calibration_datasets
-
     def build_from_dataset_name(
         self,
+        quantization_config: OVQuantizationConfigBase,
         dataset_name: str,
         num_samples: int = 100,
         dataset_config_name: Optional[str] = None,
@@ -228,29 +233,103 @@ def build_from_dataset_name(
             streaming,
         )
 
-        return self.build_from_dataset(dataset, batch_size, data_collator, remove_unused_columns)
+        return self.build_from_dataset(quantization_config, dataset, batch_size, data_collator, remove_unused_columns)
 
-    def build_from_quantization_config(
-        self,
-        quantization_config: OVQuantizationConfigBase,
-    ) -> Dict[str, nncf.Dataset]:
+    def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
         if isinstance(self, OVModelForCausalLM):
-            return self._prepare_causal_lm_calibration_data(self, quantization_config)
+            return self._prepare_causal_lm_calibration_data(self, config)
         elif isinstance(self, (OVModelForVisualCausalLM, OVModelForSpeechSeq2Seq)):
-            if quantization_config.processor is None:
+            if config.processor is None:
                 raise ValueError(
                     "`processor` must be specified in order to run data-aware quantization. Please provide it as a"
                     "model id, or a path to a directory containing all the required configuration files."
                 )
 
+            trc = config.trust_remote_code
+            processor = AutoProcessor.from_pretrained(config.processor, trust_remote_code=trc)
             if isinstance(self, OVModelForVisualCausalLM):
-                return self._prepare_visual_causal_lm_calibration_data(self, quantization_config)
-            elif isinstance(self, OVModelForSpeechSeq2Seq):
-                return self._prepare_speech_to_text_calibration_data(self, quantization_config)
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=trc)
+                    tokenizer_error = None
+                except Exception as tokenizer_error:  # noqa: F841
+                    tokenizer = None
+
+                dataset_metadata = PREDEFINED_VISUAL_LM_DATASETS[config.dataset]
+
+                def preprocess_function(item):
+                    inputs_metadata = dataset_metadata["inputs"]
+                    instruction = item[inputs_metadata["instruction"]]
+                    image_url = item[inputs_metadata["image_url"]]
+
+                    image = Image.open(requests.get(image_url, stream=True).raw)
+
+                    try:
+                        inputs = self.model.preprocess_inputs(
+                            text=instruction, image=image, processor=processor, tokenizer=tokenizer,
+                            config=self.model.config
+                        )
+                    except ValueError as value_error:
+                        if "Tokenizer is required." in str(value_error) and tokenizer_error is not None:
+                            raise tokenizer_error
+                        raise value_error
+
+                    return inputs
+
+                return self.build_from_dataset_name(
+                    config,
+                    config.dataset,
+                    config.num_samples or 32,
+                    dataset_split=dataset_metadata["split"],
+                    preprocess_function=preprocess_function,
+                    trust_remote_code=trc,
+                )
+            elif isinstance(self.model, OVModelForSpeechSeq2Seq):
+                dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset]
+
+                def preprocess_function(item):
+                    audio = item
+                    for key_name in dataset_metadata["inputs"]["audio"]:
+                        audio = audio[key_name]
+
+                    sampling_rate = item
+                    for key_name in dataset_metadata["inputs"]["sampling_rate"]:
+                        sampling_rate = sampling_rate[key_name]
+
+                    input_features = processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_features
+
+                    return input_features
+
+                return self.build_from_dataset_name(
+                    config,
+                    dataset_metadata["id"],
+                    config.num_samples or 128,
+                    dataset_metadata["name"],
+                    dataset_metadata["split"],
+                    preprocess_function=preprocess_function,
+                    trust_remote_code=trc,
+                    streaming=dataset_metadata["streaming"],
+                )
+            else:
+                raise Exception
         elif isinstance(self, OVDiffusionPipeline):
-            dataset = quantization_config.dataset
+            dataset = config.dataset
             if isinstance(dataset, str):
-                return self._prepare_diffusion_calibration_data(self, dataset_name=quantization_config.dataset, num_samples=quantization_config.num_samples)
+                dataset_name = dataset
+                dataset_metadata = PREDEFINED_DIFFUSION_DATASETS[dataset_name]
+
+                def preprocess_function(item):
+                    return {inp_name: item[column] for inp_name, column in dataset_metadata["inputs"].items()}
+
+                dataset = self._load_dataset(
+                    dataset_name,
+                    dataset_split=dataset_metadata["split"],
+                    preprocess_function=preprocess_function,
+                    streaming=dataset_metadata["streaming"],
+                )
+            elif not(isinstance(dataset, list) and all(isinstance(it, str) for it in dataset)):
+                raise Exception
+
+            return self.build_from_dataset(config, dataset)
 
     def _load_dataset(
         self,
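
Note: the new branches above index into predefined dataset metadata registries. Their expected shape can be inferred from the lookups in this hunk; the sketch below is illustrative only (placeholder keys and values, not the real entries, which live in optimum.intel.openvino.utils):

# Inferred metadata shape (hypothetical entries for illustration)
PREDEFINED_VISUAL_LM_DATASETS = {
    "some-vlm-dataset": {
        "split": "test",
        "inputs": {"instruction": "instruction", "image_url": "image_url"},
    },
}

PREDEFINED_SPEECH_TO_TEXT_DATASETS = {
    "some-asr-dataset": {
        "id": "org/asr-dataset",
        "name": "default",
        "split": "validation",
        "streaming": True,
        # preprocess_function walks these nested key paths per item:
        "inputs": {"audio": ["audio", "array"], "sampling_rate": ["audio", "sampling_rate"]},
    },
}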
@@ -306,10 +385,10 @@ def _load_dataset(
             datasets_kwargs["trust_remote_code"] = trust_remote_code
 
         dataset = load_dataset(dataset_name, **datasets_kwargs)
+        dataset = dataset.shuffle(seed=self.seed)
 
         if num_samples is not None:
-            num_samples = min(num_samples, len(dataset))
-            dataset = dataset.shuffle(seed=self.seed).select(range(num_samples))
+            dataset = dataset.select(range(min(num_samples, len(dataset))))
 
         if preprocess_function is not None:
             dataset = dataset.map(preprocess_function, batched=preprocess_batch)
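
Note: this hunk moves shuffling out of the num_samples branch, so the dataset is now shuffled even when no sample cap is given. When a cap is given, the result is unchanged, as this toy check with a datasets.Dataset shows (a sketch, not part of the commit):

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})

# Pre-commit: shuffle and select happened together, only when num_samples was set
old = ds.shuffle(seed=42).select(range(min(4, len(ds))))

# Post-commit: shuffle always runs; select only runs when num_samples is not None
new = ds.shuffle(seed=42)
new = new.select(range(min(4, len(new))))

assert old["text"] == new["text"]  # identical whenever a sample cap is provided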
@@ -347,7 +426,7 @@ def _remove_unused_columns(self, dataset: "Dataset"):
         return dataset.remove_columns(ignored_columns)
 
     def _prepare_decoder_calibration_data(
-        self, dataloader: OVDataLoader, num_samples: Optional[int] = 200
+        self, dataloader: OVDataLoader, num_samples: int = 200
     ) -> Dict[str, nncf.Dataset]:
         # Prefetch past_key_values
         self.model.update_pkv_precision(True)
@@ -385,43 +464,9 @@ def _prepare_causal_lm_calibration_data(self, config: OVQuantizationConfigBase,
 
         return {"model": calibration_dataset}
 
-    def _prepare_visual_causal_lm_calibration_data(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
-        dataset_metadata = PREDEFINED_VISUAL_LM_DATASETS[config.dataset]
-
-        def preprocess_function(item):
-            inputs_metadata = dataset_metadata["inputs"]
-            return item[inputs_metadata["instruction"]], item[inputs_metadata["image_url"]]
-
-        num_samples = config.num_samples or 32
-        dataset = self._load_dataset(
-            config.dataset,
-            num_samples,
-            dataset_split=dataset_metadata["split"],
-            preprocess_function=preprocess_function,
-            trust_remote_code=config.trust_remote_code,
-        )
-        dataloader = self._get_calibration_dataloader(dataset)
-
-        processor = AutoProcessor.from_pretrained(config.processor, trust_remote_code=config.trust_remote_code)
-        try:
-            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=config.trust_remote_code)
-            tokenizer_error = None
-        except Exception as tokenizer_error:  # noqa: F841
-            tokenizer = None
-
-        calibration_dataset = []
-        for instruction, image_url in tqdm(dataloader, desc="Collecting calibration dataset", total=num_samples):
-            image = Image.open(requests.get(image_url, stream=True).raw)
-
-            try:
-                inputs = self.model.preprocess_inputs(
-                    text=instruction, image=image, processor=processor, tokenizer=tokenizer, config=self.model.config
-                )
-            except ValueError as value_error:
-                if "Tokenizer is required." in str(value_error) and tokenizer_error is not None:
-                    raise tokenizer_error
-                raise value_error
-
+    def _prepare_visual_causal_lm_calibration_data(self, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
+        calibration_data = []
+        for inputs in tqdm(dataloader, desc="Collecting calibration dataset"):
             input_ids = inputs.get("input_ids")
             position_ids = torch.arange(input_ids.size(1)).unsqueeze(0).to(input_ids.device)
 
@@ -437,35 +482,11 @@ def preprocess_function(item):
                 inputs_embeds=inputs_embeds,
             )
 
-            calibration_dataset.append(language_model_inputs)
-
-        return {"language_model": nncf.Dataset(calibration_dataset)}
-
-    def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
-        dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset]
-
-        def preprocess_function(item):
-            audio = item
-            for key_name in dataset_metadata["inputs"]["audio"]:
-                audio = audio[key_name]
-
-            sampling_rate = item
-            for key_name in dataset_metadata["inputs"]["sampling_rate"]:
-                sampling_rate = sampling_rate[key_name]
+            calibration_data.append(language_model_inputs)
 
-            return audio, sampling_rate
-
-        num_samples = config.num_samples or 128
-        dataloader = self._get_calibration_dataloader(
-            dataset_metadata["id"],
-            num_samples,
-            dataset_metadata["name"],
-            dataset_metadata["split"],
-            preprocess_function=preprocess_function,
-            trust_remote_code=config.trust_remote_code,
-            streaming=True,
-        )
+        return {"language_model": nncf.Dataset(calibration_data)}
 
+    def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num_samples: int) -> Dict[str, nncf.Dataset]:
         encoder_calibration_data = []
         encoder_model = self.model.encoder
         encoder_model._compile()
@@ -489,13 +510,11 @@ def preprocess_function(item):
             decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True
         )
 
-        processor = AutoProcessor.from_pretrained(config.processor)
         try:
             # Download audio inputs beforehand to avoid possible connection issues
             audio_inputs = list(tqdm(dataloader, desc="Downloading audio inputs", total=num_samples))
 
-            for audio, sampling_rate in tqdm(audio_inputs, desc="Collecting calibration data"):
-                input_features = processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_features
+            for input_features in tqdm(audio_inputs, desc="Collecting calibration data"):
                 self.model.generate(input_features)
         finally:
             encoder_model.request = encoder_model.request.request
@@ -512,68 +531,40 @@ def preprocess_function(item):
         return datasets
 
     def _prepare_diffusion_calibration_data(
-        self,
-        dataloader: Optional[OVDataLoader] = None,
-        dataset_name: Optional[str] = None,
-        num_samples: Optional[int] = None,
+        self, dataloader: OVDataLoader, num_samples: int = 200
     ) -> Dict[str, nncf.Dataset]:
         self.model.compile()
 
-        diffuser = self.model.unet if self.model.unet is not None else self.model.transformer
+        diffuser_model_name = "unet" if self.model.unet is not None else "transformer"
+        diffuser = getattr(self, diffuser_model_name)
 
         size = diffuser.config.get("sample_size", 64) * self.model.vae_scale_factor
         height, width = 2 * (min(size, 512),)
-        num_samples = num_samples or 200
-
-        if dataset is not None:
-            if isinstance(dataset, nncf.Dataset):
-                return dataset
-            if is_datasets_available() and isinstance(dataset, Dataset):
-                dataset = dataset.select_columns(["caption"])
 
-            def transform_fn(data_item):
-                return data_item if isinstance(data_item, (list, dict)) else [data_item]
-
-        elif isinstance(dataset_name, str):
-            available_datasets = PREDEFINED_SD_DATASETS.keys()
-            if dataset_name not in available_datasets:
-                raise ValueError(
-                    f"""You have entered a string value for dataset. You can only choose between
-                    {list(available_datasets)}, but the {dataset_name} was found"""
-                )
-
-            from datasets import load_dataset
-
-            dataset_metadata = PREDEFINED_SD_DATASETS[dataset_name]
-            datasets_kwargs = {"split": dataset_metadata["split"], "streaming": True}
-            dataset = load_dataset(dataset_name, **datasets_kwargs).shuffle(seed=self.seed)
-
-            input_names = dataset_metadata["inputs"]
-            dataset = dataset.select_columns(list(input_names.values()))
-
-            def transform_fn(data_item):
-                return {inp_name: data_item[column] for inp_name, column in input_names.items()}
-
-        else:
-            raise ValueError(
-                "For UNet inputs collection either quantization_config.dataset or custom "
-                "calibration_dataset must be provided."
-            )
+        # TODO: move the logic below to ov_quantizer
+        # if dataset is not None:
+        #     if isinstance(dataset, nncf.Dataset):
+        #         return dataset
+        #     if is_datasets_available() and isinstance(dataset, Dataset):
+        #         dataset = dataset.select_columns(["caption"])
+        #
+        #     def transform_fn(data_item):
+        #         return data_item if isinstance(data_item, (list, dict)) else [data_item]
 
         calibration_data = []
         try:
             diffuser.request = InferRequestWrapper(diffuser.request, calibration_data)
 
-            for inputs in dataset:
-                inputs = transform_fn(inputs)
+            for inputs in tqdm(dataloader, desc="Collecting calibration data"):
                 if isinstance(inputs, dict):
                     self.model(**inputs, height=height, width=width)
+                elif isinstance(inputs, str):
+                    self.model(inputs, height=height, width=width)
                 else:
                     self.model(*inputs, height=height, width=width)
                 if len(calibration_data) >= num_samples:
                     break
         finally:
            diffuser.request = diffuser.request.request
 
-        calibration_dataset = nncf.Dataset(calibration_data[:num_samples])
-        return calibration_dataset
+        return {diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])}
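
Note: taken together, this commit threads the quantization config through every entry point and makes each builder method return a name-to-dataset mapping. A minimal usage sketch follows; the builder class name is a hypothetical placeholder and the config fields are assumptions, as only the method signatures appear in this diff:

from optimum.intel import OVModelForCausalLM, OVQuantizationConfig

model = OVModelForCausalLM.from_pretrained("model-id", export=True)  # placeholder model id
config = OVQuantizationConfig(dataset="wikitext2", num_samples=128)  # fields assumed

builder = OVCalibrationDatasetBuilder(model, seed=42)  # hypothetical class name
calibration_datasets = builder.build_from_quantization_config(config)
# -> Dict[str, nncf.Dataset], e.g. {"model": ...}, {"language_model": ...},
#    or {"unet": ...} / {"transformer": ...} for diffusion pipelines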
