@@ -162,21 +162,20 @@ def build_from_dataset(
         dataloader = self._get_calibration_dataloader(dataset, batch_size, data_collator, remove_unused_columns)
 
         if isinstance(self.model, OVBaseDecoderModel):
-            return self._prepare_decoder_calibration_data(dataloader, quantization_config.num_samples)
+            return self._prepare_decoder_calibration_data(quantization_config, dataloader)
         elif isinstance(self.model, OVModelForVisualCausalLM):
-            return self._prepare_visual_causal_lm_calibration_data(dataloader)
+            return self._prepare_visual_causal_lm_calibration_data(quantization_config, dataloader)
         elif isinstance(self.model, OVModelForSpeechSeq2Seq):
-            return self._prepare_speech_to_text_calibration_data(dataloader, quantization_config.num_samples)
+            return self._prepare_speech_to_text_calibration_data(quantization_config, dataloader)
         elif isinstance(self.model, OVDiffusionPipeline):
-            return self._prepare_diffusion_calibration_data(dataloader=dataloader, num_samples=quantization_config.num_samples)
+            return self._prepare_diffusion_calibration_data(quantization_config, dataloader)
         else:
             raise Exception
 
     def build_from_dataset_name(
         self,
         quantization_config: OVQuantizationConfigBase,
         dataset_name: str,
-        num_samples: int = 100,
         dataset_config_name: Optional[str] = None,
         dataset_split: str = "train",
         preprocess_function: Optional[Callable] = None,
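Since every `_prepare_*` helper now receives the quantization config first and reads `num_samples` itself, the dispatch no longer unpacks config fields at the call site. A minimal standalone sketch of that pattern (stub classes and names for illustration, not the optimum-intel API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantizationConfig:
    # Plays the role of OVQuantizationConfigBase.num_samples: None means "use the per-modality default".
    num_samples: Optional[int] = None


def prepare_decoder_data(config: QuantizationConfig, dataloader):
    # Each helper resolves its own default instead of taking num_samples as a parameter.
    num_samples = config.num_samples or 200
    return list(dataloader)[:num_samples]


def build_from_dataset(config: QuantizationConfig, dataloader):
    # Config-first dispatch: the helper, not the caller, decides how num_samples is applied.
    return prepare_decoder_data(config, dataloader)


print(len(build_from_dataset(QuantizationConfig(), range(500))))  # 200: default cap
print(len(build_from_dataset(QuantizationConfig(num_samples=3), range(500))))  # 3
```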
@@ -196,8 +195,6 @@ def build_from_dataset_name(
             dataset_name (`str`):
                 The dataset repository name on the Hugging Face Hub or path to a local directory containing data files
                 in generic formats and optionally a dataset script, if it requires some code to read the data files.
-            num_samples (`int`, defaults to 100):
-                The maximum number of samples composing the calibration dataset.
             dataset_config_name (`str`, *optional*):
                 The name of the dataset configuration.
             dataset_split (`str`, defaults to `"train"`):
@@ -222,7 +219,6 @@ def build_from_dataset_name(
 
         dataset = self._load_dataset(
             dataset_name,
-            num_samples,
             dataset_config_name,
             dataset_split,
             preprocess_function,
@@ -278,7 +274,6 @@ def preprocess_function(item):
            return self.build_from_dataset_name(
                config,
                config.dataset,
-               config.num_samples or 32,
                dataset_split=dataset_metadata["split"],
                preprocess_function=preprocess_function,
                trust_remote_code=trc,
@@ -302,7 +297,6 @@ def preprocess_function(item):
            return self.build_from_dataset_name(
                config,
                dataset_metadata["id"],
-               config.num_samples or 128,
                dataset_metadata["name"],
                dataset_metadata["split"],
                preprocess_function=preprocess_function,
@@ -334,7 +328,6 @@ def preprocess_function(item):
     def _load_dataset(
         self,
         dataset_name: str,
-        num_samples: int = 100,
         dataset_config_name: Optional[str] = None,
         dataset_split: str = "train",
         preprocess_function: Optional[Callable] = None,
@@ -351,8 +344,6 @@ def _load_dataset(
             dataset_name (`str`):
                 The dataset repository name on the Hugging Face Hub or path to a local directory containing data files
                 in generic formats and optionally a dataset script, if it requires some code to read the data files.
-            num_samples (`int`, defaults to 100):
-                The maximum number of samples composing the calibration dataset.
             dataset_config_name (`str`, *optional*):
                 The name of the dataset configuration.
             dataset_split (`str`, defaults to `"train"`):
@@ -387,9 +378,6 @@ def _load_dataset(
         dataset = load_dataset(dataset_name, **datasets_kwargs)
         dataset = dataset.shuffle(seed=self.seed)
 
-        if num_samples is not None:
-            dataset = dataset.select(range(min(num_samples, len(dataset))))
-
         if preprocess_function is not None:
             dataset = dataset.map(preprocess_function, batched=preprocess_batch)
 
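The removed `select(...)` call truncated the dataset up front; after this change the cap is enforced while samples are being collected (see the `len(...) >= num_samples` checks in the hunks below). A small illustration of the two approaches, using a plain list in place of a `datasets.Dataset`:

```python
samples = list(range(10))
num_samples = 4

# Before: truncate the dataset itself, then iterate everything that is left.
truncated = samples[: min(num_samples, len(samples))]
collected_before = list(truncated)

# After: iterate the full dataset but stop once enough samples are collected.
collected_after = []
for s in samples:
    collected_after.append(s)
    if len(collected_after) >= num_samples:
        break

assert collected_before == collected_after == [0, 1, 2, 3]
```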
@@ -426,16 +414,17 @@ def _remove_unused_columns(self, dataset: "Dataset"):
         return dataset.remove_columns(ignored_columns)
 
     def _prepare_decoder_calibration_data(
-        self, dataloader: OVDataLoader, num_samples: int = 200
+        self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
     ) -> Dict[str, nncf.Dataset]:
         # Prefetch past_key_values
         self.model.update_pkv_precision(True)
         self.model.compile()
         collected_inputs = []
 
+        num_samples = quantization_config.num_samples or 200
         self.model.request = InferRequestWrapper(self.model.request, collected_inputs)
         try:
-            for data in dataloader:
+            for data in tqdm(dataloader, desc="Collecting calibration data"):
                self.model.generate(**data, max_new_tokens=1)
                if len(collected_inputs) >= num_samples:
                    break
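For readers unfamiliar with `InferRequestWrapper`: it proxies the compiled model's infer request and records every set of inputs that passes through it, so calibration data falls out of ordinary `generate()` calls. A rough standalone sketch of that capture pattern (hypothetical names, not the actual optimum-intel implementation):

```python
import copy


class RecordingWrapper:
    """Proxy that stores a copy of each input dict before delegating the call."""

    def __init__(self, wrapped, storage):
        self.wrapped = wrapped
        self.storage = storage

    def __call__(self, inputs):
        self.storage.append(copy.deepcopy(inputs))
        return self.wrapped(inputs)


collected = []
model = RecordingWrapper(lambda inputs: sum(inputs.values()), collected)
model({"a": 1, "b": 2})
model({"a": 3, "b": 4})
print(collected)  # [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
```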
@@ -464,9 +453,10 @@ def _prepare_causal_lm_calibration_data(self, config: OVQuantizationConfigBase,
 
         return {"model": calibration_dataset}
 
-    def _prepare_visual_causal_lm_calibration_data(self, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
+    def _prepare_visual_causal_lm_calibration_data(self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
         calibration_data = []
-        for inputs in tqdm(dataloader, desc="Collecting calibration dataset"):
+        num_samples = quantization_config.num_samples or 32
+        for inputs in tqdm(dataloader, desc="Collecting calibration dataset", total=num_samples):
            input_ids = inputs.get("input_ids")
            position_ids = torch.arange(input_ids.size(1)).unsqueeze(0).to(input_ids.device)
 
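The `position_ids` line builds a `[1, seq_len]` tensor of positions `0..seq_len-1` on the same device as the prompt, since the language-model subgraph is invoked directly here rather than through `generate()`. In isolation:

```python
import torch

input_ids = torch.tensor([[101, 2023, 2003, 102]])  # shape [batch=1, seq_len=4]
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0).to(input_ids.device)
print(position_ids)  # tensor([[0, 1, 2, 3]])
```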
@@ -484,9 +474,12 @@ def _prepare_visual_causal_lm_calibration_data(self, dataloader: OVDataLoader) -
 
             calibration_data.append(language_model_inputs)
 
+            if len(calibration_data) >= num_samples:
+                break
+
         return {"language_model": nncf.Dataset(calibration_data)}
 
-    def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num_samples: int) -> Dict[str, nncf.Dataset]:
+    def _prepare_speech_to_text_calibration_data(self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
         encoder_calibration_data = []
         encoder_model = self.model.encoder
         encoder_model._compile()
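One subtlety of the `quantization_config.num_samples or 32` idiom used throughout this patch: `or` falls back on any falsy value, so an explicit `num_samples=0` is treated the same as `None`. A quick check:

```python
for configured in (None, 0, 16):
    resolved = configured or 32
    print(configured, "->", resolved)
# None -> 32
# 0 -> 32   (0 also falls through to the default)
# 16 -> 16
```

If zero were ever meant to mean "collect no samples", an explicit `if num_samples is None` test would be needed instead.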
@@ -512,6 +505,7 @@ def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num
 
         try:
             # Download audio inputs beforehand to avoid possible connection issues
+            num_samples = quantization_config.num_samples or 32
             audio_inputs = list(tqdm(dataloader, desc="Downloading audio inputs", total=num_samples))
 
             for input_features in tqdm(audio_inputs, desc="Collecting calibration data"):
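Note that `total=num_samples` only tells tqdm how long to draw the progress bar; it does not stop iteration, so the `list(...)` still drains whatever the dataloader yields (assuming it is not already capped upstream). If one also wanted to cap the downloads themselves, `itertools.islice` would do it — a hypothetical variant, not what this patch does:

```python
from itertools import islice

from tqdm import tqdm

dataloader = iter(range(1000))  # stands in for the audio dataloader
num_samples = 32

# Capped: only the first num_samples items are pulled from the iterator,
# and total= sizes the bar to match.
audio_inputs = list(tqdm(islice(dataloader, num_samples), total=num_samples))
print(len(audio_inputs))  # 32
```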
@@ -531,7 +525,7 @@ def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num
         return datasets
 
     def _prepare_diffusion_calibration_data(
-        self, dataloader: OVDataLoader, num_samples: int = 200
+        self, quantization_config: OVQuantizationConfigBase, dataloader: OVDataLoader
     ) -> Dict[str, nncf.Dataset]:
         self.model.compile()
 
@@ -551,6 +545,7 @@ def _prepare_diffusion_calibration_data(
         # def transform_fn(data_item):
         #     return data_item if isinstance(data_item, (list, dict)) else [data_item]
 
+        num_samples = quantization_config.num_samples or 200
         calibration_data = []
         try:
             diffuser.request = InferRequestWrapper(diffuser.request, calibration_data)
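Taken together, each modality now resolves its own fallback when `num_samples` is unset: 200 for decoders and diffusion pipelines, 32 for visual causal LMs and speech-to-text. A tiny helper expressing that convention (illustrative only, the patch inlines the resolution in each method):

```python
from typing import Optional

DEFAULT_NUM_SAMPLES = {
    "decoder": 200,
    "visual_causal_lm": 32,
    "speech_to_text": 32,
    "diffusion": 200,
}


def resolve_num_samples(configured: Optional[int], model_kind: str) -> int:
    # Same `x or default` resolution the patch applies inside each helper.
    return configured or DEFAULT_NUM_SAMPLES[model_kind]


print(resolve_num_samples(None, "diffusion"))  # 200
print(resolve_num_samples(64, "decoder"))      # 64
```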