@@ -355,7 +355,7 @@ def _quantize_ovbasemodel(
355
355
from optimum .intel import OVModelForCausalLM
356
356
357
357
if isinstance (self .model , OVModelForCausalLM ):
358
- calibration_dataset = self ._prepare_builtin_dataset (quantization_config )
358
+ calibration_dataset = self ._prepare_causal_lm_dataset (quantization_config )
359
359
elif is_diffusers_available () and isinstance (self .model , OVStableDiffusionPipelineBase ):
360
360
calibration_dataset = self ._prepare_unet_dataset (
361
361
quantization_config .num_samples , dataset_name = quantization_config .dataset
@@ -669,19 +669,20 @@ def _remove_unused_columns(self, dataset: "Dataset"):
669
669
ignored_columns = list (set (dataset .column_names ) - set (self ._signature_columns ))
670
670
return dataset .remove_columns (ignored_columns )
671
671
672
- def _prepare_builtin_dataset (self , quantization_config : OVWeightQuantizationConfig ):
672
+ def _prepare_causal_lm_dataset (self , quantization_config : OVWeightQuantizationConfig ):
673
673
from optimum .gptq .data import get_dataset , prepare_dataset
674
674
675
675
tokenizer = AutoTokenizer .from_pretrained (
676
676
quantization_config .tokenizer , trust_remote_code = quantization_config .trust_remote_code
677
677
)
678
678
nsamples = quantization_config .num_samples if quantization_config .num_samples else 128
679
- if isinstance (quantization_config .dataset , str ):
680
- calibration_dataset = get_dataset (quantization_config .dataset , tokenizer , seqlen = 32 , nsamples = nsamples )
679
+ config_dataset = quantization_config .dataset
680
+ if isinstance (config_dataset , str ):
681
+ calibration_dataset = get_dataset (config_dataset , tokenizer , seqlen = 32 , nsamples = nsamples )
682
+ elif isinstance (config_dataset , list ) and all ([isinstance (it , str ) for it in config_dataset ]):
683
+ calibration_dataset = [tokenizer (text , return_tensors = "pt" ) for text in config_dataset [:nsamples ]]
681
684
else :
682
- calibration_dataset = [
683
- tokenizer (text , return_tensors = "pt" ) for text in quantization_config .dataset [:nsamples ]
684
- ]
685
+ raise ValueError ("Please provide dataset as one of the accepted dataset labels or as a list of strings." )
685
686
calibration_dataset = prepare_dataset (calibration_dataset )
686
687
calibration_dataset = nncf .Dataset (calibration_dataset , lambda x : self .model .prepare_inputs (** x ))
687
688
0 commit comments