from tqdm import tqdm
from transformers import DataCollator, default_data_collator, AutoTokenizer, AutoProcessor

-from optimum.intel import is_accelerate_available, OVBaseDecoderModel, OVModelForCausalLM, OVModelForVisualCausalLM, \
+from optimum.intel import is_accelerate_available, OVModelForCausalLM, OVModelForVisualCausalLM, \
    OVModelForSpeechSeq2Seq, OVDiffusionPipeline
+from optimum.intel.openvino.modeling_decoder import OVBaseDecoderModel
from optimum.intel.openvino.quantization import OVQuantizationConfigBase
-from optimum.intel.openvino.utils import PREDEFINED_VISUAL_LM_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS
+from optimum.intel.openvino.utils import PREDEFINED_VISUAL_LM_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS, \
+    PREDEFINED_DIFFUSION_DATASETS
from optimum.intel.utils.import_utils import is_datasets_available, DATASETS_IMPORT_ERROR, is_datasets_version

if is_datasets_available():
@@ -149,8 +151,8 @@ def __init__(self, model: transformers.PreTrainedModel, seed: int = 42, **kwargs

    def build_from_dataset(
        self,
+        quantization_config: OVQuantizationConfigBase,
        dataset: Union["Dataset", Sized],
-        num_samples: Optional[int],
        batch_size: Optional[int] = 1,
        data_collator: Optional[DataCollator] = None,
        remove_unused_columns: bool = False,
@@ -160,16 +162,19 @@ def build_from_dataset(
        dataloader = self._get_calibration_dataloader(dataset, batch_size, data_collator, remove_unused_columns)

        if isinstance(self.model, OVBaseDecoderModel):
-            calibration_datasets = self._prepare_decoder_calibration_data(dataloader, num_samples)
+            return self._prepare_decoder_calibration_data(dataloader, quantization_config.num_samples)
+        elif isinstance(self.model, OVModelForVisualCausalLM):
+            return self._prepare_visual_causal_lm_calibration_data(dataloader)
+        elif isinstance(self.model, OVModelForSpeechSeq2Seq):
+            return self._prepare_speech_to_text_calibration_data(dataloader, quantization_config.num_samples)
        elif isinstance(self.model, OVDiffusionPipeline):
-            calibration_datasets = self._prepare_diffusion_calibration_data(dataloader=dataloader, num_samples=num_samples)
+            return self._prepare_diffusion_calibration_data(dataloader=dataloader, num_samples=quantization_config.num_samples)
        else:
            raise Exception

-        return calibration_datasets
-
    def build_from_dataset_name(
        self,
+        quantization_config: OVQuantizationConfigBase,
        dataset_name: str,
        num_samples: int = 100,
        dataset_config_name: Optional[str] = None,
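
Note for reviewers: after this hunk, `num_samples` is no longer threaded through `build_from_dataset` as a positional argument; it is read from the config object, and each model-type branch returns its prepared datasets directly. A minimal sketch of the new call shape, assuming a builder wrapping an `OVModelForCausalLM`; the builder class name, model ID, and dataset ID below are illustrative placeholders, not part of this diff:

```python
from datasets import load_dataset
from optimum.intel import OVModelForCausalLM, OVQuantizationConfig

# Hypothetical usage; "OVCalibrationDatasetBuilder" is a placeholder name for the
# builder class this diff edits, and the model/dataset IDs are illustrative.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True)
config = OVQuantizationConfig(bits=8, num_samples=64)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

builder = OVCalibrationDatasetBuilder(model)  # placeholder class name
# num_samples is now taken from the config instead of a separate positional argument:
calibration_datasets = builder.build_from_dataset(config, dataset)
```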
@@ -228,29 +233,103 @@ def build_from_dataset_name(
            streaming,
        )

-        return self.build_from_dataset(dataset, batch_size, data_collator, remove_unused_columns)
+        return self.build_from_dataset(quantization_config, dataset, batch_size, data_collator, remove_unused_columns)

-    def build_from_quantization_config(
-        self,
-        quantization_config: OVQuantizationConfigBase,
-    ) -> Dict[str, nncf.Dataset]:
+    def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
        if isinstance(self, OVModelForCausalLM):
-            return self._prepare_causal_lm_calibration_data(self, quantization_config)
+            return self._prepare_causal_lm_calibration_data(self, config)
        elif isinstance(self, (OVModelForVisualCausalLM, OVModelForSpeechSeq2Seq)):
-            if quantization_config.processor is None:
+            if config.processor is None:
                raise ValueError(
                    "`processor` must be specified in order to run data-aware quantization. Please provide it as a"
                    "model id, or a path to a directory containing all the required configuration files."
                )

+            trc = config.trust_remote_code
+            processor = AutoProcessor.from_pretrained(config.processor, trust_remote_code=trc)
            if isinstance(self, OVModelForVisualCausalLM):
-                return self._prepare_visual_causal_lm_calibration_data(self, quantization_config)
-            elif isinstance(self, OVModelForSpeechSeq2Seq):
-                return self._prepare_speech_to_text_calibration_data(self, quantization_config)
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=trc)
+                    tokenizer_error = None
+                except Exception as tokenizer_error:  # noqa: F841
+                    tokenizer = None
+
+                dataset_metadata = PREDEFINED_VISUAL_LM_DATASETS[config.dataset]
+
+                def preprocess_function(item):
+                    inputs_metadata = dataset_metadata["inputs"]
+                    instruction = item[inputs_metadata["instruction"]]
+                    image_url = item[inputs_metadata["image_url"]]
+
+                    image = Image.open(requests.get(image_url, stream=True).raw)
+
+                    try:
+                        inputs = self.model.preprocess_inputs(
+                            text=instruction, image=image, processor=processor, tokenizer=tokenizer,
+                            config=self.model.config
+                        )
+                    except ValueError as value_error:
+                        if "Tokenizer is required." in str(value_error) and tokenizer_error is not None:
+                            raise tokenizer_error
+                        raise value_error
+
+                    return inputs
+
+                return self.build_from_dataset_name(
+                    config,
+                    config.dataset,
+                    config.num_samples or 32,
+                    dataset_split=dataset_metadata["split"],
+                    preprocess_function=preprocess_function,
+                    trust_remote_code=trc,
+                )
+            elif isinstance(self.model, OVModelForSpeechSeq2Seq):
+                dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset]
+
+                def preprocess_function(item):
+                    audio = item
+                    for key_name in dataset_metadata["inputs"]["audio"]:
+                        audio = audio[key_name]
+
+                    sampling_rate = item
+                    for key_name in dataset_metadata["inputs"]["sampling_rate"]:
+                        sampling_rate = sampling_rate[key_name]
+
+                    input_features = processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_features
+
+                    return input_features
+
+                return self.build_from_dataset_name(
+                    config,
+                    dataset_metadata["id"],
+                    config.num_samples or 128,
+                    dataset_metadata["name"],
+                    dataset_metadata["split"],
+                    preprocess_function=preprocess_function,
+                    trust_remote_code=trc,
+                    streaming=dataset_metadata["streaming"],
+                )
+            else:
+                raise Exception
        elif isinstance(self, OVDiffusionPipeline):
-            dataset = quantization_config.dataset
+            dataset = config.dataset
            if isinstance(dataset, str):
-                return self._prepare_diffusion_calibration_data(self, dataset_name=quantization_config.dataset, num_samples=quantization_config.num_samples)
+                dataset_name = dataset
+                dataset_metadata = PREDEFINED_DIFFUSION_DATASETS[dataset_name]
+
+                def preprocess_function(item):
+                    return {inp_name: item[column] for inp_name, column in dataset_metadata["inputs"].items()}
+
+                dataset = self._load_dataset(
+                    dataset_name,
+                    dataset_split=dataset_metadata["split"],
+                    preprocess_function=preprocess_function,
+                    streaming=dataset_metadata["streaming"],
+                )
+            elif not (isinstance(dataset, list) and all(isinstance(it, str) for it in dataset)):
+                raise Exception
+
+            return self.build_from_dataset(config, dataset)

    def _load_dataset(
        self,
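
As a reading aid: the inlined `preprocess_function`s above assume specific shapes for the predefined dataset registries. The sketch below mirrors only the keys the new code reads; the concrete values are illustrative guesses, not the actual registry contents in `optimum.intel.openvino.utils`.

```python
# Keys consumed by build_from_quantization_config above; the values are made up.
EXAMPLE_VISUAL_LM_ENTRY = {
    "split": "train",
    "inputs": {"instruction": "question", "image_url": "image_url"},
}
EXAMPLE_SPEECH_TO_TEXT_ENTRY = {
    "id": "org/asr-dataset",  # passed as the positional dataset_name argument
    "name": "clean",          # dataset_config_name
    "split": "validation",
    "streaming": True,
    # Nested key paths walked item[k1][k2]... by the preprocess_function loops:
    "inputs": {"audio": ["audio", "array"], "sampling_rate": ["audio", "sampling_rate"]},
}
EXAMPLE_DIFFUSION_ENTRY = {
    "split": "train",
    "streaming": True,
    "inputs": {"prompt": "caption"},  # model input name -> dataset column
}
```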
@@ -306,10 +385,10 @@ def _load_dataset(
            datasets_kwargs["trust_remote_code"] = trust_remote_code

        dataset = load_dataset(dataset_name, **datasets_kwargs)
+        dataset = dataset.shuffle(seed=self.seed)

        if num_samples is not None:
-            num_samples = min(num_samples, len(dataset))
-            dataset = dataset.shuffle(seed=self.seed).select(range(num_samples))
+            dataset = dataset.select(range(min(num_samples, len(dataset))))

        if preprocess_function is not None:
            dataset = dataset.map(preprocess_function, batched=preprocess_batch)
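
The reordering here applies the seeded shuffle unconditionally, so the dataset is shuffled even when `num_samples` is None; only the subsetting stays conditional. An equivalent standalone snippet using the `datasets` API directly (dataset ID illustrative):

```python
from datasets import load_dataset

dataset = load_dataset("ag_news", split="train")
dataset = dataset.shuffle(seed=42)  # shuffle first, unconditionally

num_samples = 100  # may also be None
if num_samples is not None:
    dataset = dataset.select(range(min(num_samples, len(dataset))))
```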
@@ -347,7 +426,7 @@ def _remove_unused_columns(self, dataset: "Dataset"):
        return dataset.remove_columns(ignored_columns)

    def _prepare_decoder_calibration_data(
-        self, dataloader: OVDataLoader, num_samples: Optional[int] = 200
+        self, dataloader: OVDataLoader, num_samples: int = 200
    ) -> Dict[str, nncf.Dataset]:
        # Prefetch past_key_values
        self.model.update_pkv_precision(True)
@@ -385,43 +464,9 @@ def _prepare_causal_lm_calibration_data(self, config: OVQuantizationConfigBase,

        return {"model": calibration_dataset}

-    def _prepare_visual_causal_lm_calibration_data(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
-        dataset_metadata = PREDEFINED_VISUAL_LM_DATASETS[config.dataset]
-
-        def preprocess_function(item):
-            inputs_metadata = dataset_metadata["inputs"]
-            return item[inputs_metadata["instruction"]], item[inputs_metadata["image_url"]]
-
-        num_samples = config.num_samples or 32
-        dataset = self._load_dataset(
-            config.dataset,
-            num_samples,
-            dataset_split=dataset_metadata["split"],
-            preprocess_function=preprocess_function,
-            trust_remote_code=config.trust_remote_code,
-        )
-        dataloader = self._get_calibration_dataloader(dataset)
-
-        processor = AutoProcessor.from_pretrained(config.processor, trust_remote_code=config.trust_remote_code)
-        try:
-            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=config.trust_remote_code)
-            tokenizer_error = None
-        except Exception as tokenizer_error:  # noqa: F841
-            tokenizer = None
-
-        calibration_dataset = []
-        for instruction, image_url in tqdm(dataloader, desc="Collecting calibration dataset", total=num_samples):
-            image = Image.open(requests.get(image_url, stream=True).raw)
-
-            try:
-                inputs = self.model.preprocess_inputs(
-                    text=instruction, image=image, processor=processor, tokenizer=tokenizer, config=self.model.config
-                )
-            except ValueError as value_error:
-                if "Tokenizer is required." in str(value_error) and tokenizer_error is not None:
-                    raise tokenizer_error
-                raise value_error
-
+    def _prepare_visual_causal_lm_calibration_data(self, dataloader: OVDataLoader) -> Dict[str, nncf.Dataset]:
+        calibration_data = []
+        for inputs in tqdm(dataloader, desc="Collecting calibration dataset"):
            input_ids = inputs.get("input_ids")
            position_ids = torch.arange(input_ids.size(1)).unsqueeze(0).to(input_ids.device)
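
With this hunk the per-item work (downloading the image, calling `preprocess_inputs`) lives in the `preprocess_function` passed to `build_from_dataset_name`, so the loop here receives ready model inputs and no longer does network I/O. For reference, the image-fetch step that moved upstream is just (URL illustrative):

```python
import requests
from PIL import Image

image_url = "https://example.com/sample.png"  # illustrative URL
image = Image.open(requests.get(image_url, stream=True).raw)
```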
@@ -437,35 +482,11 @@ def preprocess_function(item):
                inputs_embeds=inputs_embeds,
            )

-            calibration_dataset.append(language_model_inputs)
-
-        return {"language_model": nncf.Dataset(calibration_dataset)}
-
-    def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigBase) -> Dict[str, nncf.Dataset]:
-        dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset]
-
-        def preprocess_function(item):
-            audio = item
-            for key_name in dataset_metadata["inputs"]["audio"]:
-                audio = audio[key_name]
-
-            sampling_rate = item
-            for key_name in dataset_metadata["inputs"]["sampling_rate"]:
-                sampling_rate = sampling_rate[key_name]
+            calibration_data.append(language_model_inputs)

-            return audio, sampling_rate
-
-        num_samples = config.num_samples or 128
-        dataloader = self._get_calibration_dataloader(
-            dataset_metadata["id"],
-            num_samples,
-            dataset_metadata["name"],
-            dataset_metadata["split"],
-            preprocess_function=preprocess_function,
-            trust_remote_code=config.trust_remote_code,
-            streaming=True,
-        )
+        return {"language_model": nncf.Dataset(calibration_data)}

+    def _prepare_speech_to_text_calibration_data(self, dataloader: OVDataLoader, num_samples: int) -> Dict[str, nncf.Dataset]:
        encoder_calibration_data = []
        encoder_model = self.model.encoder
        encoder_model._compile()
@@ -489,13 +510,11 @@ def preprocess_function(item):
                decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True
            )

-        processor = AutoProcessor.from_pretrained(config.processor)
        try:
            # Download audio inputs beforehand to avoid possible connection issues
            audio_inputs = list(tqdm(dataloader, desc="Downloading audio inputs", total=num_samples))

-            for audio, sampling_rate in tqdm(audio_inputs, desc="Collecting calibration data"):
-                input_features = processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_features
+            for input_features in tqdm(audio_inputs, desc="Collecting calibration data"):
                self.model.generate(input_features)
        finally:
            encoder_model.request = encoder_model.request.request
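
The `AutoProcessor` call drops out of this method because feature extraction now happens in the `preprocess_function` defined in `build_from_quantization_config`, so the dataloader already yields `input_features` tensors. For reference, the extraction that moved looks like this (model ID and dummy audio are illustrative):

```python
import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
audio = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
```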
@@ -512,68 +531,40 @@ def preprocess_function(item):
        return datasets

    def _prepare_diffusion_calibration_data(
-        self,
-        dataloader: Optional[OVDataLoader] = None,
-        dataset_name: Optional[str] = None,
-        num_samples: Optional[int] = None,
+        self, dataloader: OVDataLoader, num_samples: int = 200
    ) -> Dict[str, nncf.Dataset]:
        self.model.compile()

-        diffuser = self.model.unet if self.model.unet is not None else self.model.transformer
+        diffuser_model_name = "unet" if self.model.unet is not None else "transformer"
+        diffuser = getattr(self.model, diffuser_model_name)

        size = diffuser.config.get("sample_size", 64) * self.model.vae_scale_factor
        height, width = 2 * (min(size, 512),)
-        num_samples = num_samples or 200
-
-        if dataset is not None:
-            if isinstance(dataset, nncf.Dataset):
-                return dataset
-            if is_datasets_available() and isinstance(dataset, Dataset):
-                dataset = dataset.select_columns(["caption"])

-            def transform_fn(data_item):
-                return data_item if isinstance(data_item, (list, dict)) else [data_item]
-
-        elif isinstance(dataset_name, str):
-            available_datasets = PREDEFINED_SD_DATASETS.keys()
-            if dataset_name not in available_datasets:
-                raise ValueError(
-                    f"""You have entered a string value for dataset. You can only choose between
-                    {list(available_datasets)}, but the {dataset_name} was found"""
-                )
-
-            from datasets import load_dataset
-
-            dataset_metadata = PREDEFINED_SD_DATASETS[dataset_name]
-            datasets_kwargs = {"split": dataset_metadata["split"], "streaming": True}
-            dataset = load_dataset(dataset_name, **datasets_kwargs).shuffle(seed=self.seed)
-
-            input_names = dataset_metadata["inputs"]
-            dataset = dataset.select_columns(list(input_names.values()))
-
-            def transform_fn(data_item):
-                return {inp_name: data_item[column] for inp_name, column in input_names.items()}
-
-        else:
-            raise ValueError(
-                "For UNet inputs collection either quantization_config.dataset or custom "
-                "calibration_dataset must be provided."
-            )
+        # TODO: move the logic below to ov_quantizer
+        # if dataset is not None:
+        #     if isinstance(dataset, nncf.Dataset):
+        #         return dataset
+        #     if is_datasets_available() and isinstance(dataset, Dataset):
+        #         dataset = dataset.select_columns(["caption"])
+        #
+        #     def transform_fn(data_item):
+        #         return data_item if isinstance(data_item, (list, dict)) else [data_item]

        calibration_data = []
        try:
            diffuser.request = InferRequestWrapper(diffuser.request, calibration_data)

-            for inputs in dataset:
-                inputs = transform_fn(inputs)
+            for inputs in tqdm(dataloader, desc="Collecting calibration data"):
                if isinstance(inputs, dict):
                    self.model(**inputs, height=height, width=width)
+                elif isinstance(inputs, str):
+                    self.model(inputs, height=height, width=width)
                else:
                    self.model(*inputs, height=height, width=width)
                if len(calibration_data) >= num_samples:
                    break
        finally:
            diffuser.request = diffuser.request.request

-        calibration_dataset = nncf.Dataset(calibration_data[:num_samples])
-        return calibration_dataset
+        return {diffuser_model_name: nncf.Dataset(calibration_data[:num_samples])}
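
With the diffusion branch now returning `{diffuser_model_name: nncf.Dataset(...)}`, every `_prepare_*_calibration_data` helper yields the same `Dict[str, nncf.Dataset]` shape keyed by submodel ("model", "language_model", the encoder/decoder entries, or "unet"/"transformer"), which is what lets `build_from_dataset` return each branch's result directly. A consuming sketch (the builder instance is hypothetical, as above):

```python
# calibration_datasets: Dict[str, nncf.Dataset], e.g. {"unet": ...} for SD pipelines.
calibration_datasets = builder.build_from_dataset(config, dataset)
for submodel_name, nncf_dataset in calibration_datasets.items():
    print(f"collected calibration data for: {submodel_name}")
```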