@@ -662,7 +662,7 @@ def _prepare_speech_to_text_calibration_data(
662
662
return CalibrationDataset (calibration_data )
663
663
664
664
def _prepare_diffusion_calibration_data (
665
- self , config : OVQuantizationConfigBase , dataset : "Dataset"
665
+ self , config : OVQuantizationConfigBase , dataset : Union [ List , "Dataset" ]
666
666
) -> CalibrationDataset :
667
667
"""
668
668
Prepares calibration data for diffusion models by inferring it on a dataset. Currently, collects data only for
@@ -679,19 +679,24 @@ def _prepare_diffusion_calibration_data(
679
679
num_samples = config .num_samples or 200
680
680
calibration_data = []
681
681
try :
682
+ self .disable_progress_bar (disable = True )
682
683
diffuser .request = InferRequestWrapper (diffuser .request , calibration_data )
683
684
684
- for item in tqdm (dataset , desc = "Collecting calibration data" ):
685
+ pbar = tqdm (total = num_samples , desc = "Collecting calibration data" )
686
+ for item in dataset :
685
687
prompt = (
686
688
item [PREDEFINED_DIFFUSION_DATASETS [config .dataset ]["prompt_column_name" ]]
687
689
if isinstance (item , dict )
688
690
else item
689
691
)
690
692
self .model (prompt , height = height , width = width )
693
+ pbar .update (min (num_samples , len (calibration_data )) - pbar .n )
691
694
if len (calibration_data ) >= num_samples :
695
+ calibration_data = calibration_data [:num_samples ]
692
696
break
693
697
finally :
694
698
diffuser .request = diffuser .request .request
699
+ self .disable_progress_bar (disable = False )
695
700
696
701
return CalibrationDataset ({diffuser_model_name : nncf .Dataset (calibration_data [:num_samples ])})
697
702
@@ -700,3 +705,9 @@ def _remove_unused_columns(self, dataset: "Dataset"):
700
705
# for example there is model.generate()
701
706
ignored_columns = list (set (dataset .column_names ) - set (self ._signature_columns ))
702
707
return dataset .remove_columns (ignored_columns )
708
+
709
+ def disable_progress_bar (self , disable : bool = True ) -> None :
710
+ if not hasattr (self .model , "_progress_bar_config" ):
711
+ self .model ._progress_bar_config = {"disable" : disable }
712
+ else :
713
+ self .model ._progress_bar_config ["disable" ] = disable
0 commit comments