
Commit bf91033

add quantization_config argument for OVModel

1 parent: 925e56d

6 files changed: +111 −66 lines changed

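In short, the OVModel classes gain a `quantization_config` argument that flows from `from_pretrained` down to `load_model`, subsuming the old `load_in_8bit`/`load_in_4bit` flags. A minimal sketch of the resulting call pattern (the model id is illustrative; `OVWeightQuantizationConfig` comes from `optimum.intel.openvino.configuration` per the hunks below):

    from optimum.intel import OVModelForCausalLM
    from optimum.intel.openvino.configuration import OVWeightQuantizationConfig

    # Export a Transformers checkpoint to OpenVINO IR and compress weights to 8-bit.
    # Passing a config object is the new path; load_in_8bit=True still works.
    model = OVModelForCausalLM.from_pretrained(
        "gpt2",  # illustrative model id
        export=True,
        quantization_config=OVWeightQuantizationConfig(bits=8),
    )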

optimum/exporters/openvino/convert.py

+8 −20
@@ -120,11 +120,8 @@ def export(
         device (`str`, *optional*, defaults to `cpu`):
             The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
             export on CUDA devices.
-        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point.
-        compression_ratio (`Optional[float]`, defaults to `None`):
-            Compression ratio between primary and backup precision (only relevant to INT4).
+        ov_config (`OVConfig`, *optional*):
+            The configuration containing the parameters related to quantization.
         input_shapes (`Optional[Dict]`, defaults to `None`):
             If specified, allows to use specific shapes for the example input provided to the exporter.
         stateful (`bool`, defaults to `True`):
@@ -233,11 +230,8 @@ def export_pytorch_via_onnx(
             If specified, allows to use specific shapes for the example input provided to the exporter.
         model_kwargs (optional[Dict[str, Any]], defaults to `None`):
             Additional kwargs for model export.
-        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point.
-        compression_ratio (`Optional[float]`, defaults to `None`):
-            Compression ratio between primary and backup precision (only relevant to INT4).
+        ov_config (`OVConfig`, *optional*):
+            The configuration containing the parameters related to quantization.

     Returns:
         `Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -290,11 +284,8 @@ def export_pytorch(
             If specified, allows to use specific shapes for the example input provided to the exporter.
         model_kwargs (optional[Dict[str, Any]], defaults to `None`):
             Additional kwargs for model export
-        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point.
-        compression_ratio (`Optional[float]`, defaults to `None`):
-            Compression ratio between primary and backup precision (only relevant to INT4).
+        ov_config (`OVConfig`, *optional*):
+            The configuration containing the parameters related to quantization.
         stateful (`bool`, defaults to `False`):
             Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models.

@@ -452,11 +443,8 @@ def export_models(
             export on CUDA devices.
         input_shapes (Optional[Dict], optional, Defaults to None):
             If specified, allows to use specific shapes for the example input provided to the exporter.
-        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point.
-        compression_ratio (`Optional[int]`, defaults to `None`):
-            Compression ratio between primary and backup precision (only relevant to INT4).
+        ov_config (`OVConfig`, *optional*):
+            The configuration containing the parameters related to quantization.
         model_kwargs (Optional[Dict[str, Any]], optional):
             Additional kwargs for model export.
         stateful (`bool`, defaults to `True`)
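The `compression_option`/`compression_ratio` pair collapses into one `ov_config` argument. A hedged sketch of how an old option string might map onto the new objects (field names follow `OVWeightQuantizationConfig` as used elsewhere in this commit; the exact correspondence is an assumption):

    from optimum.intel.openvino.configuration import OVConfig, OVWeightQuantizationConfig

    # Before: compression_option="int4_sym_g128", compression_ratio=0.8
    # After (assumed equivalent): an OVConfig carrying a weight-quantization config
    ov_config = OVConfig(
        quantization_config=OVWeightQuantizationConfig(
            bits=4,          # INT4 weights
            sym=True,        # symmetric quantization, no zero-point
            group_size=128,  # weights quantized in groups of 128
            ratio=0.8,       # share of layers kept in the 4-bit primary precision
        )
    )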

optimum/intel/openvino/modeling.py

+16 −6
@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import numpy as np
 import openvino
@@ -53,6 +53,7 @@

 from ...exporters.openvino import main_export
 from ..utils.import_utils import is_timm_available, is_timm_version
+from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
 from .utils import _is_timm_ov_dir

@@ -427,14 +428,17 @@ def _from_transformers(
         task: Optional[str] = None,
         trust_remote_code: bool = False,
         load_in_8bit: Optional[bool] = None,
-        load_in_4bit: Optional[bool] = None,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
         **kwargs,
     ):
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)

-        # If load_in_8bit is not specified then compression_option should be set to None and will be set by default in main_export depending on the model size
-        compression_option = "fp32" if load_in_8bit is not None else None
+        # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
+        if load_in_8bit is None or not quantization_config:
+            ov_config = None
+        else:
+            ov_config = OVConfig(dtype="fp32")

         # OVModelForFeatureExtraction works with Transformers type of models, thus even sentence-transformers models are loaded as such.
         main_export(
@@ -448,12 +452,18 @@
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            compression_option=compression_option,
+            ov_config=ov_config,
             library_name="transformers",
         )

         config.save_pretrained(save_dir_path)
-        return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=load_in_8bit, **kwargs)
+        return cls._from_pretrained(
+            model_id=save_dir_path,
+            config=config,
+            load_in_8bit=load_in_8bit,
+            quantization_config=quantization_config,
+            **kwargs,
+        )


 MASKED_LM_EXAMPLE = r"""
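The same plumbing reaches the OVModel subclasses defined in this file; per the retained comment, even sentence-transformers checkpoints are exported through the Transformers path. A sketch of the user-facing call (model id illustrative; a plain dict is accepted since `_from_pretrained` forwards it as-is):

    from optimum.intel import OVModelForFeatureExtraction

    model = OVModelForFeatureExtraction.from_pretrained(
        "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model id
        export=True,
        quantization_config={"bits": 8},  # dict shorthand for an 8-bit config
    )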

optimum/intel/openvino/modeling_base.py

+27 −11
@@ -31,6 +31,7 @@

 from ...exporters.openvino import export, main_export
 from ..utils.import_utils import is_nncf_available
+from .configuration import OVConfig, OVWeightQuantizationConfig
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, _print_compiled_model_properties


@@ -91,7 +92,7 @@ def __init__(
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

     @staticmethod
-    def load_model(file_name: Union[str, Path], load_in_8bit: bool = False):
+    def load_model(file_name: Union[str, Path], quantization_config: Union[OVWeightQuantizationConfig, Dict] = None):
        """
        Loads the model.

@@ -118,14 +119,15 @@ def fix_op_names_duplicates(model: openvino.runtime.Model):
         if file_name.suffix == ".onnx":
             model = fix_op_names_duplicates(model)  # should be called during model conversion to IR

-        if load_in_8bit:
+        if quantization_config:
             if not is_nncf_available():
                 raise ImportError(
                     "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
                 )
-            import nncf

-            model = nncf.compress_weights(model)
+            from optimum.intel.openvino.quantization import _weight_only_quantization
+
+            model = _weight_only_quantization(model, quantization_config)

         return model

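`_weight_only_quantization` itself is not part of this diff. As a rough, assumed reading of what `load_model` now delegates to — normalize a dict config, then forward to NNCF — the helper's real body may differ:

    import nncf
    from optimum.intel.openvino.configuration import OVWeightQuantizationConfig

    def _weight_only_quantization_sketch(model, quantization_config):
        # Accept either a dict such as {"bits": 8} or a ready config object.
        if isinstance(quantization_config, dict):
            quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
        if quantization_config.bits == 8:
            # 8-bit weight compression, matching the old load_in_8bit behaviour.
            return nncf.compress_weights(model)
        # 4-bit modes map onto NNCF's CompressWeightsMode with group size and ratio.
        mode = (
            nncf.CompressWeightsMode.INT4_SYM
            if quantization_config.sym
            else nncf.CompressWeightsMode.INT4_ASYM
        )
        return nncf.compress_weights(
            model,
            mode=mode,
            ratio=quantization_config.ratio,
            group_size=quantization_config.group_size,
        )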
@@ -155,6 +157,7 @@ def _from_pretrained(
         from_onnx: bool = False,
         local_files_only: bool = False,
         load_in_8bit: bool = False,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
         **kwargs,
     ):
         """
@@ -199,7 +202,12 @@
             subfolder=subfolder,
             local_files_only=local_files_only,
         )
-        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
+
+        # Give default quantization config if not provided and load_in_8bit=True
+        if load_in_8bit:
+            quantization_config = quantization_config or {"bits": 8}
+
+        model = cls.load_model(model_cache_path, quantization_config=quantization_config)
         return cls(model, config=config, model_save_dir=model_cache_path.parent, **kwargs)

     @staticmethod
@@ -252,6 +260,7 @@ def _from_transformers(
         task: Optional[str] = None,
         trust_remote_code: bool = False,
         load_in_8bit: Optional[bool] = None,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
         **kwargs,
     ):
         """
@@ -275,10 +284,11 @@
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)

-        # If load_in_8bit is not specified then compression_option should be set to None and will be set by default in main_export depending on the model size
-        compression_option = None
-        if load_in_8bit is not None:
-            compression_option = "fp32"
+        # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
+        if load_in_8bit is None or not quantization_config:
+            ov_config = None
+        else:
+            ov_config = OVConfig(dtype="fp32")

         main_export(
             model_name_or_path=model_id,
@@ -291,11 +301,17 @@
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            compression_option=compression_option,
+            ov_config=ov_config,
         )

         config.save_pretrained(save_dir_path)
-        return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=load_in_8bit, **kwargs)
+        return cls._from_pretrained(
+            model_id=save_dir_path,
+            config=config,
+            load_in_8bit=load_in_8bit,
+            quantization_config=quantization_config,
+            **kwargs,
+        )

     @classmethod
     def _to_load(
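The `{"bits": 8}` fallback keeps `load_in_8bit=True` working: a bare dict is treated as shorthand for an 8-bit weight-only config, normalized later via `from_dict` (as the `modeling_decoder.py` hunks below show). A small illustration of the assumed equivalence:

    from optimum.intel.openvino.configuration import OVWeightQuantizationConfig

    # Both should describe the same 8-bit weight-only quantization.
    from_shorthand = OVWeightQuantizationConfig.from_dict({"bits": 8})
    explicit = OVWeightQuantizationConfig(bits=8)
    assert from_shorthand.bits == explicit.bits == 8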

optimum/intel/openvino/modeling_base_seq2seq.py

+29 −12
@@ -25,6 +25,7 @@
 from transformers.file_utils import add_start_docstrings

 from ...exporters.openvino import main_export
+from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
 from .utils import (
     ONNX_DECODER_NAME,
@@ -111,6 +112,7 @@ def _from_pretrained(
         use_cache: bool = True,
         from_onnx: bool = False,
         load_in_8bit: bool = False,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
         **kwargs,
     ):
         """
@@ -152,12 +154,19 @@
         decoder_file_name = decoder_file_name or default_decoder_file_name
         decoder_with_past_file_name = decoder_with_past_file_name or default_decoder_with_past_file_name
         decoder_with_past = None
+
+        # Give default quantization config if not provided and load_in_8bit=True
+        if load_in_8bit:
+            quantization_config = quantization_config or {"bits": 8}
+
         # Load model from a local directory
         if os.path.isdir(model_id):
-            encoder = cls.load_model(os.path.join(model_id, encoder_file_name), load_in_8bit)
-            decoder = cls.load_model(os.path.join(model_id, decoder_file_name), load_in_8bit)
+            encoder = cls.load_model(os.path.join(model_id, encoder_file_name), quantization_config)
+            decoder = cls.load_model(os.path.join(model_id, decoder_file_name), quantization_config)
             if use_cache:
-                decoder_with_past = cls.load_model(os.path.join(model_id, decoder_with_past_file_name), load_in_8bit)
+                decoder_with_past = cls.load_model(
+                    os.path.join(model_id, decoder_with_past_file_name), quantization_config
+                )

             model_save_dir = Path(model_id)

@@ -185,10 +194,10 @@
             file_names[name] = model_cache_path

         model_save_dir = Path(model_cache_path).parent
-        encoder = cls.load_model(file_names["encoder"], load_in_8bit)
-        decoder = cls.load_model(file_names["decoder"], load_in_8bit)
+        encoder = cls.load_model(file_names["encoder"], quantization_config)
+        decoder = cls.load_model(file_names["decoder"], quantization_config)
         if use_cache:
-            decoder_with_past = cls.load_model(file_names["decoder_with_past"], load_in_8bit)
+            decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config)

         return cls(
             encoder=encoder,
@@ -214,6 +223,7 @@ def _from_transformers(
         use_cache: bool = True,
         trust_remote_code: bool = False,
         load_in_8bit: Optional[bool] = None,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
         **kwargs,
     ):
         """
@@ -240,13 +250,15 @@

         if task is None:
             task = cls.export_feature
-
         if use_cache:
             task = task + "-with-past"

-        compression_option = None
-        if load_in_8bit is not None:
-            compression_option = "fp32"
+        # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
+        if load_in_8bit is None or not quantization_config:
+            ov_config = None
+        else:
+            ov_config = OVConfig(dtype="fp32")
+
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -258,12 +270,17 @@
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            compression_option=compression_option,
+            ov_config=ov_config,
         )

         config.save_pretrained(save_dir_path)
         return cls._from_pretrained(
-            model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=load_in_8bit, **kwargs
+            model_id=save_dir_path,
+            config=config,
+            use_cache=use_cache,
+            load_in_8bit=load_in_8bit,
+            quantization_config=quantization_config,
+            **kwargs,
         )

     def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
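For seq2seq models one config is applied uniformly to the encoder, the decoder, and (with `use_cache=True`) the decoder-with-past. Sketch of the user-facing call (model id illustrative):

    from optimum.intel import OVModelForSeq2SeqLM

    # load_in_8bit=True now expands internally to quantization_config={"bits": 8},
    # so all three submodels are compressed with the same settings.
    model = OVModelForSeq2SeqLM.from_pretrained(
        "t5-small",  # illustrative model id
        export=True,
        load_in_8bit=True,
    )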

optimum/intel/openvino/modeling_decoder.py

+13 −9
@@ -34,7 +34,7 @@
 from ...exporters.openvino.stateful import model_has_state
 from ..utils.import_utils import is_nncf_available
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
-from .configuration import OVWeightQuantizationConfig, _check_default_4bit_configs
+from .configuration import OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE

@@ -252,14 +252,14 @@ def _from_transformers(

         if task is None:
             task = cls.export_feature
-
         if use_cache:
             task = task + "-with-past"

-        # If load_in_8bit is not specified then compression_option should be set to None and will be set by default in main_export depending on the model size
-        compression_option = None
-        if load_in_8bit is not None or quantization_config is not None:
-            compression_option = "fp32"
+        # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
+        if load_in_8bit is None or not quantization_config:
+            ov_config = None
+        else:
+            ov_config = OVConfig(dtype="fp32")

         stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)

@@ -274,7 +274,7 @@
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            compression_option=compression_option,
+            ov_config=ov_config,
             stateful=stateful,
         )

@@ -285,8 +285,8 @@
             model_id=save_dir_path,
             config=config,
             use_cache=use_cache,
-            load_in_8bit=load_in_8bit,
             stateful=None,
+            load_in_8bit=load_in_8bit,
             quantization_config=quantization_config,
             **kwargs,
         )
@@ -576,11 +576,15 @@ def _from_pretrained(
             local_files_only=local_files_only,
         )

+        # Give default quantization config if not provided and load_in_8bit=True
+        if load_in_8bit:
+            quantization_config = quantization_config or {"bits": 8}
+
         if isinstance(quantization_config, dict):
             quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)

         load_in_4bit = quantization_config.bits == 4 if quantization_config else False
-        model = cls.load_model(model_cache_path, load_in_8bit=False if load_in_4bit else load_in_8bit)
+        model = cls.load_model(model_cache_path, quantization_config=None if load_in_4bit else quantization_config)

         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
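For decoder models a 4-bit config deliberately skips compression inside `load_model` (`quantization_config=None` is passed when `bits == 4`), the 4-bit pass being applied in a separate step. A hedged example of requesting 4-bit weights (model id illustrative):

    from optimum.intel import OVModelForCausalLM
    from optimum.intel.openvino.configuration import OVWeightQuantizationConfig

    # bits=4 makes load_in_4bit True above, so load_model() receives
    # quantization_config=None and 4-bit compression happens later.
    model = OVModelForCausalLM.from_pretrained(
        "some-org/some-llm",  # illustrative model id
        quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=128),
    )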
