Skip to content

Commit 4adfe78

Browse files
WIP 3
1 parent 1b960db commit 4adfe78

File tree

1 file changed

+17
-22
lines changed

1 file changed

+17
-22
lines changed

optimum/intel/openvino/quantization/dataset_builder.py

+17-22
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,20 @@ def build_from_dataset(
162162
dataloader = self._get_calibration_dataloader(dataset, batch_size, data_collator, remove_unused_columns)
163163

164164
if isinstance(self.model, OVBaseDecoderModel):
165-
return self._prepare_decoder_calibration_data(dataloader, quantization_config.num_samples)
165+
return self._prepare_decoder_calibration_data(quantization_config, dataloader)
166166
elif isinstance(self.model, OVModelForVisualCausalLM):
167-
return self._prepare_visual_causal_lm_calibration_data(dataloader)
167+
return self._prepare_visual_causal_lm_calibration_data(quantization_config, dataloader)
168168
elif isinstance(self.model, OVModelForSpeechSeq2Seq):
169-
return self._prepare_speech_to_text_calibration_data(dataloader, quantization_config.num_samples)
169+
return self._prepare_speech_to_text_calibration_data(quantization_config, dataloader)
170170
elif isinstance(self.model, OVDiffusionPipeline):
171-
return self._prepare_diffusion_calibration_data(dataloader=dataloader, num_samples=quantization_config.num_samples)
171+
return self._prepare_diffusion_calibration_data(quantization_config, dataloader)
172172
else:
173173
raise Exception
174174

175175
def build_from_dataset_name(
176176
self,
177177
quantization_config: OVQuantizationConfigBase,
178178
dataset_name: str,
179-
num_samples: int = 100,
180179
dataset_config_name: Optional[str] = None,
181180
dataset_split: str = "train",
182181
preprocess_function: Optional[Callable] = None,
@@ -196,8 +195,6 @@ def build_from_dataset_name(
196195
dataset_name (`str`):
197196
The dataset repository name on the Hugging Face Hub or path to a local directory containing data files
198197
in generic formats and optionally a dataset script, if it requires some code to read the data files.
199-
num_samples (`int`, defaults to 100):
200-
The maximum number of samples composing the calibration dataset.
201198
dataset_config_name (`str`, *optional*):
202199
The name of the dataset configuration.
203200
dataset_split (`str`, defaults to `"train"`):
@@ -222,7 +219,6 @@ def build_from_dataset_name(
222219

223220
dataset = self._load_dataset(
224221
dataset_name,
225-
num_samples,
226222
dataset_config_name,
227223
dataset_split,
228224
preprocess_function,
@@ -278,7 +274,6 @@ def preprocess_function(item):
278274
return self.build_from_dataset_name(
279275
config,
280276
config.dataset,
281-
config.num_samples or 32,
282277
dataset_split=dataset_metadata["split"],
283278
preprocess_function=preprocess_function,
284279
trust_remote_code=trc,
@@ -302,7 +297,6 @@ def preprocess_function(item):
302297
return self.build_from_dataset_name(
303298
config,
304299
dataset_metadata["id"],
305-
config.num_samples or 128,
306300
dataset_metadata["name"],
307301
dataset_metadata["split"],
308302
preprocess_function=preprocess_function,
@@ -334,7 +328,6 @@ def preprocess_function(item):
334328
def _load_dataset(
335329
self,
336330
dataset_name: str,
337-
num_samples: int = 100,
338331
dataset_config_name: Optional[str] = None,
339332
dataset_split: str = "train",
340333
preprocess_function: Optional[Callable] = None,
@@ -351,8 +344,6 @@ def _load_dataset(
351344
dataset_name (`str`):
352345
The dataset repository name on the Hugging Face Hub or path to a local directory containing data files
353346
in generic formats and optionally a dataset script, if it requires some code to read the data files.
354-
num_samples (`int`, defaults to 100):
355-
The maximum number of samples composing the calibration dataset.
356347
dataset_config_name (`str`, *optional*):
357348
The name of the dataset configuration.
358349
dataset_split (`str`, defaults to `"train"`):
@@ -387,9 +378,6 @@ def _load_dataset(
387378
dataset = load_dataset(dataset_name, **datasets_kwargs)
388379
dataset = dataset.shuffle(seed=self.seed)
389380

390-
if num_samples is not None:
391-
dataset = dataset.select(range(min(num_samples, len(dataset))))
392-
393381
if preprocess_function is not None:
394382
dataset = dataset.map(preprocess_function, batched=preprocess_batch)
395383

@@ -426,16 +414,17 @@ def _remove_unused_columns(self, dataset: "Dataset"):
426414
return dataset.remove_columns(ignored_columns)
427415

428416
def _prepare_decoder_calibration_data(
429-
self, dataloader: OVDataLoader, num_samples: int = 200
417+
self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
430418
) -> Dict[str, nncf.Dataset]:
431419
# Prefetch past_key_values
432420
self.model.update_pkv_precision(True)
433421
self.model.compile()
434422
collected_inputs = []
435423

424+
num_samples = quantization_config.num_samples or 200
436425
self.model.request = InferRequestWrapper(self.model.request, collected_inputs)
437426
try:
438-
for data in dataloader:
427+
for data in tqdm(dataloader, desc="Collecting calibration data"):
439428
self.model.generate(**data, max_new_tokens=1)
440429
if len(collected_inputs) >= num_samples:
441430
break
@@ -464,9 +453,10 @@ def _prepare_causal_lm_calibration_data(self, config: OVQuantizationConfigBase,
464453

465454
return {"model": calibration_dataset}
466455

467-
def _prepare_visual_causal_lm_calibration_data(self, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
456+
def _prepare_visual_causal_lm_calibration_data(self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
468457
calibration_data = []
469-
for inputs in tqdm(dataloader, desc="Collecting calibration dataset"):
458+
num_samples = quantization_config.num_samples or 32
459+
for inputs in tqdm(dataloader, desc="Collecting calibration dataset", total=num_samples):
470460
input_ids = inputs.get("input_ids")
471461
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0).to(input_ids.device)
472462

@@ -484,9 +474,12 @@ def _prepare_visual_causal_lm_calibration_data(self, dataloader: OVDataLoader) -
484474

485475
calibration_data.append(language_model_inputs)
486476

477+
if len(calibration_data) >= num_samples:
478+
break
479+
487480
return {"language_model": nncf.Dataset(calibration_data)}
488481

489-
def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num_samples: int) -> Dict[str, nncf.Dataset]:
482+
def _prepare_speech_to_text_calibration_data(self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
490483
encoder_calibration_data = []
491484
encoder_model = self.model.encoder
492485
encoder_model._compile()
@@ -512,6 +505,7 @@ def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num
512505

513506
try:
514507
# Download audio inputs beforehand to avoid possible connection issues
508+
num_samples = quantization_config.num_samples or 32
515509
audio_inputs = list(tqdm(dataloader, desc="Downloading audio inputs", total=num_samples))
516510

517511
for input_features in tqdm(audio_inputs, desc="Collecting calibration data"):
@@ -531,7 +525,7 @@ def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num
531525
return datasets
532526

533527
def _prepare_diffusion_calibration_data(
534-
self, dataloader: OVDataLoader, num_samples: int = 200
528+
self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
535529
) -> Dict[str, nncf.Dataset]:
536530
self.model.compile()
537531

@@ -551,6 +545,7 @@ def _prepare_diffusion_calibration_data(
551545
# def transform_fn(data_item):
552546
# return data_item if isinstance(data_item, (list, dict)) else [data_item]
553547

548+
num_samples = quantization_config.num_samples or 200
554549
calibration_data = []
555550
try:
556551
diffuser.request = InferRequestWrapper(diffuser.request, calibration_data)

0 commit comments

Comments (0)