@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import logging
 import os
 from pathlib import Path
@@ -100,6 +101,7 @@ def __init__(
         dynamic_shapes: bool = True,
         ov_config: Optional[Dict[str, str]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
         **kwargs,
     ):
         if not dynamic_shapes:
@@ -117,6 +119,7 @@ def __init__(
             dynamic_shapes=False,
             ov_config=ov_config,
             model_save_dir=model_save_dir,
+            quantization_config=quantization_config,
             **kwargs,
         )
 
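With this change, a weight-quantization config can be threaded through the constructor alongside the other OpenVINO options. A minimal usage sketch (it assumes the public entry point is `OVModelForCausalLM` and that `from_pretrained` forwards `quantization_config` to this `__init__`; neither is shown in this diff):

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# 8-bit weight-only quantization applied while loading/exporting the model.
model = OVModelForCausalLM.from_pretrained(
    "gpt2",  # model id is illustrative
    export=True,  # export from PyTorch to OpenVINO IR
    quantization_config=OVWeightQuantizationConfig(bits=8),
)
```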
@@ -224,6 +227,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
         openvino.save_model(model_to_save, dst_path, compress_to_fp16=False)
 
+        self._save_openvino_config(save_directory)
+
     @classmethod
     def _from_transformers(
         cls,
@@ -578,15 +583,10 @@ def _from_pretrained(
             local_files_only=local_files_only,
         )
 
-        # Give default quantization config if not provided and load_in_8bit=True
-        if load_in_8bit:
-            quantization_config = quantization_config or {"bits": 8}
-
-        if isinstance(quantization_config, dict):
-            if quantization_config == {"bits": 4} and config.name_or_path in _DEFAULT_4BIT_CONFIGS:
-                quantization_config = _DEFAULT_4BIT_CONFIGS[config.name_or_path]
+        if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
+            quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
 
-            quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
+        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
 
         load_in_4bit = quantization_config.bits == 4 if quantization_config else False
         model = cls.load_model(model_cache_path, quantization_config=None if load_in_4bit else quantization_config)
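The new `cls._prepare_weight_quantization_config` centralizes the dict-to-config conversion and the `load_in_8bit` default that previously lived inline. Its body is not part of this diff; a plausible sketch, reconstructed from the logic it replaces:

```python
@staticmethod
def _prepare_weight_quantization_config(
    quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
    load_in_8bit: bool = False,
) -> Optional[OVWeightQuantizationConfig]:
    # Default to 8-bit weight-only quantization when only load_in_8bit is set.
    if quantization_config is None and load_in_8bit:
        return OVWeightQuantizationConfig(bits=8)
    # Accept a plain dict and promote it to a typed config.
    if isinstance(quantization_config, dict):
        return OVWeightQuantizationConfig.from_dict(quantization_config)
    return quantization_config
```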
@@ -605,7 +605,12 @@ def _from_pretrained(
 
         enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
         causal_model = init_cls(
-            model=model, config=config, model_save_dir=model_cache_path.parent, compile=enable_compilation, **kwargs
+            model=model,
+            config=config,
+            model_save_dir=model_cache_path.parent,
+            compile=enable_compilation,
+            quantization_config=quantization_config,
+            **kwargs,
         )
 
         if load_in_4bit:
@@ -634,6 +639,7 @@ def _from_pretrained(
             # seqlen = get_seqlen(causal_model)
             dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32)
             dataset = prepare_dataset(dataset)
+            quantization_config = copy.deepcopy(quantization_config)
             quantization_config.dataset = nncf.Dataset(dataset, lambda x: causal_model.prepare_inputs(**x))
 
             _weight_only_quantization(model, quantization_config)
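The added `copy.deepcopy` keeps the caller's config intact: the very next line overwrites `quantization_config.dataset` (a dataset name string) with an `nncf.Dataset` built from the prepared samples, and without the copy that mutation would leak into the object the user passed in. A sketch of the data-aware 4-bit path this hunk serves (class name and dataset id assumed, not shown in the diff):

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# 4-bit weight-only quantization calibrated on a named dataset; the loader
# above resolves the string to tokenized samples wrapped in an nncf.Dataset.
model = OVModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # model id is illustrative
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=4, dataset="wikitext2"),
)
```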