Skip to content

Commit e549a7c

Browse files
Address comments
1 parent 6b1cfb4 commit e549a7c

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

optimum/intel/openvino/quantization.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _quantize_ovbasemodel(
355355
from optimum.intel import OVModelForCausalLM
356356

357357
if isinstance(self.model, OVModelForCausalLM):
358-
calibration_dataset = self._prepare_builtin_dataset(quantization_config)
358+
calibration_dataset = self._prepare_causal_lm_dataset(quantization_config)
359359
elif is_diffusers_available() and isinstance(self.model, OVStableDiffusionPipelineBase):
360360
calibration_dataset = self._prepare_unet_dataset(
361361
quantization_config.num_samples, dataset_name=quantization_config.dataset
@@ -669,19 +669,20 @@ def _remove_unused_columns(self, dataset: "Dataset"):
669669
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
670670
return dataset.remove_columns(ignored_columns)
671671

672-
def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConfig):
672+
def _prepare_causal_lm_dataset(self, quantization_config: OVWeightQuantizationConfig):
673673
from optimum.gptq.data import get_dataset, prepare_dataset
674674

675675
tokenizer = AutoTokenizer.from_pretrained(
676676
quantization_config.tokenizer, trust_remote_code=quantization_config.trust_remote_code
677677
)
678678
nsamples = quantization_config.num_samples if quantization_config.num_samples else 128
679-
if isinstance(quantization_config.dataset, str):
680-
calibration_dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
679+
config_dataset = quantization_config.dataset
680+
if isinstance(config_dataset, str):
681+
calibration_dataset = get_dataset(config_dataset, tokenizer, seqlen=32, nsamples=nsamples)
682+
elif isinstance(config_dataset, list) and all([isinstance(it, str) for it in config_dataset]):
683+
calibration_dataset = [tokenizer(text, return_tensors="pt") for text in config_dataset[:nsamples]]
681684
else:
682-
calibration_dataset = [
683-
tokenizer(text, return_tensors="pt") for text in quantization_config.dataset[:nsamples]
684-
]
685+
raise ValueError("Please provide dataset as one of the accepted dataset labels or as a list of strings.")
685686
calibration_dataset = prepare_dataset(calibration_dataset)
686687
calibration_dataset = nncf.Dataset(calibration_dataset, lambda x: self.model.prepare_inputs(**x))
687688

0 commit comments

Comments
 (0)