@@ -572,7 +572,7 @@ def _from_pretrained(
         from_onnx: bool = False,
         local_files_only: bool = False,
         load_in_8bit: bool = False,
-        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
+        quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
         **kwargs,
     ):
         model_path = Path(model_id)
@@ -596,7 +596,12 @@ def _from_pretrained(
         quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
 
         load_in_4bit = quantization_config.bits == 4 if quantization_config else False
-        model = cls.load_model(model_cache_path, quantization_config=None if load_in_4bit else quantization_config)
+        calibration_dataset = kwargs.get("calibration_dataset", None)
+        model = cls.load_model(
+            model_cache_path,
+            quantization_config=None if load_in_4bit else quantization_config,
+            calibration_dataset=calibration_dataset,
+        )
 
         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
@@ -632,7 +637,7 @@ def _from_pretrained(
                 f"For the given model, we recommend the following `quantization_config` : {default_config}"
             )
 
-        if isinstance(quantization_config.dataset, str):
+        if calibration_dataset is None and isinstance(quantization_config.dataset, str):
             tokenizer = quantization_config.tokenizer or AutoTokenizer.from_pretrained(model_id)
 
             from optimum.gptq.data import get_dataset, prepare_dataset
@@ -644,9 +649,9 @@ def _from_pretrained(
             dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
             dataset = prepare_dataset(dataset)
             quantization_config = copy.deepcopy(quantization_config)
-            quantization_config.dataset = nncf.Dataset(dataset, lambda x: causal_model.prepare_inputs(**x))
+            calibration_dataset = nncf.Dataset(dataset, lambda x: causal_model.prepare_inputs(**x))
 
-        _weight_only_quantization(model, quantization_config)
+        _weight_only_quantization(model, quantization_config, calibration_dataset)
 
         return causal_model
 
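Not part of the patch above, but for context: a minimal sketch of how the new `calibration_dataset` keyword could be used from the caller's side, assuming `from_pretrained` forwards extra kwargs down to `_from_pretrained` (as elsewhere in optimum) and that `model_id` points at an already exported OpenVINO model. The model path and the tokenized samples are placeholders; the exact sample format a given model expects is not defined by this diff.

# Hedged usage sketch. Assumptions: from_pretrained forwards extra kwargs to
# _from_pretrained, and "path/to/ov-model" is an already exported OV model dir.
import nncf
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

model_id = "path/to/ov-model"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Placeholder calibration samples; the required input format depends on the model.
samples = [tokenizer("calibration sample %d" % i, return_tensors="np") for i in range(8)]

model = OVModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=OVWeightQuantizationConfig(bits=4),
    # Picked up via kwargs.get("calibration_dataset", None) in _from_pretrained;
    # when set, the string-dataset branch guarded by
    # `calibration_dataset is None and isinstance(...)` is skipped.
    calibration_dataset=nncf.Dataset(samples),
)

Passing the dataset this way bypasses the internal tokenizer/get_dataset round-trip that otherwise builds one from `quantization_config.dataset`.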