Skip to content

Commit 79a7ad8

Browse files
Fix calibration data collection progress bar
1 parent e24c693 commit 79a7ad8

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

optimum/intel/openvino/quantization/calibration_dataset_builder.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def _prepare_speech_to_text_calibration_data(
662662
return CalibrationDataset(calibration_data)
663663

664664
def _prepare_diffusion_calibration_data(
665-
self, config: OVQuantizationConfigBase, dataset: "Dataset"
665+
self, config: OVQuantizationConfigBase, dataset: Union[List, "Dataset"]
666666
) -> CalibrationDataset:
667667
"""
668668
Prepares calibration data for diffusion models by inferring it on a dataset. Currently, collects data only for
@@ -679,19 +679,24 @@ def _prepare_diffusion_calibration_data(
679679
num_samples = config.num_samples or 200
680680
calibration_data = []
681681
try:
682+
self.disable_progress_bar(disable=True)
682683
diffuser.request = InferRequestWrapper(diffuser.request, calibration_data)
683684

684-
for item in tqdm(dataset, desc="Collecting calibration data"):
685+
pbar = tqdm(total=num_samples, desc="Collecting calibration data")
686+
for item in dataset:
685687
prompt = (
686688
item[PREDEFINED_DIFFUSION_DATASETS[config.dataset]["prompt_column_name"]]
687689
if isinstance(item, dict)
688690
else item
689691
)
690692
self.model(prompt, height=height, width=width)
693+
pbar.update(min(num_samples, len(calibration_data)) - pbar.n)
691694
if len(calibration_data) >= num_samples:
695+
calibration_data = calibration_data[:num_samples]
692696
break
693697
finally:
694698
diffuser.request = diffuser.request.request
699+
self.disable_progress_bar(disable=False)
695700

696701
return CalibrationDataset({diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])})
697702

@@ -700,3 +705,9 @@ def _remove_unused_columns(self, dataset: "Dataset"):
700705
# for example there is model.generate()
701706
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
702707
return dataset.remove_columns(ignored_columns)
708+
709+
def disable_progress_bar(self, disable: bool = True) -> None:
710+
if not hasattr(self.model, "_progress_bar_config"):
711+
self.model._progress_bar_config = {"disable": disable}
712+
else:
713+
self.model._progress_bar_config["disable"] = disable

0 commit comments

Comments
 (0)