
Commit 2d14e25

Deprecate compression options (#565)
* deprecate compression options
* style
* fix configuration
* Update CLI argument
* update documentation
* deprecate torch nn modules for ov quantizer
* fix ov config for fp32 models
* fix format
* update documentation
* Add check for configuration
* fix ratio default value for SD models
* add quantization_config argument for OVModel
* remove commented line
* Update docs/source/inference.mdx
  Co-authored-by: Alexander Kozlov <alexander.kozlov@intel.com>
* add default config for causal LM
* fix warning message

---------

Co-authored-by: Alexander Kozlov <alexander.kozlov@intel.com>
1 parent 62f570f commit 2d14e25

18 files changed: +362 -227 lines

README.md

+1 -1
@@ -126,7 +126,7 @@ from optimum.intel import OVQuantizer, OVModelForSequenceClassification
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 model_id = "distilbert-base-uncased-finetuned-sst-2-english"
-model = AutoModelForSequenceClassification.from_pretrained(model_id)
+model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 def preprocess_fn(examples, tokenizer):
     return tokenizer(
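The change above swaps the vanilla `AutoModelForSequenceClassification` for its OpenVINO counterpart. As a quick orientation (not part of this diff), `export=True` converts the checkpoint to OpenVINO IR on the fly, and the resulting model can be saved like any other optimum model; a minimal sketch, assuming the standard `save_pretrained` API and an illustrative output folder name:

```python
from optimum.intel import OVModelForSequenceClassification

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
# export=True converts the PyTorch checkpoint to OpenVINO IR while loading
model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
# "ov_distilbert" is an illustrative directory name, not taken from the diff
model.save_pretrained("ov_distilbert")
```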

docs/source/inference.mdx

+12 -14
@@ -47,6 +47,8 @@ Here we set the `task` to `text-generation-with-past`, with the `-with-past` suf
 optimum-cli export openvino --model local_path --task text-generation-with-past ov_model
 ```
 
+To export your model in fp16, you can add `--weight-format fp16` when exporting your model.
+
 Once the model is exported, you can load the OpenVINO model using :
 
 ```python
@@ -96,15 +98,23 @@ tokenizer.save_pretrained(save_directory)
 
 ### Weight-only quantization
 
-You can also apply 8-bit or 4-bit weight quantization when exporting your model with the CLI:
+You can also apply 8-bit or 4-bit weight quantization when exporting your model with the CLI by setting the `weight-format` argument to respectively `int8` or `int4`:
 
 ```bash
 optimum-cli export openvino --model gpt2 --weight-format int8 ov_model
 ```
 
 This will result in the exported model linear and embedding layers to be quantized to INT8 or INT4, the activations will be kept in floating point precision. This type of optimization allows reducing the footprint and latency of LLMs.
 
-This can also be done when loading your model by setting the `load_in_8bit` argument when calling the `from_pretrained()` method.
+By default the quantization scheme will be [asymmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization); to make it [symmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) you can add `--sym`.
+
+For INT4 quantization you can also specify the following arguments:
+* The `--group-size` parameter will define the group size to use for quantization; `-1` will result in per-column quantization.
+* The `--ratio` CLI parameter controls the ratio between 4-bit and 8-bit quantization. If set to 0.9, it means that 90% of the layers will be quantized to `int4` while 10% will be quantized to `int8`.
+
+Smaller `group_size` and `ratio` values usually improve accuracy at the sacrifice of model size and inference latency.
+
+You can also apply 8-bit quantization on your model's weights when loading your model by setting the `load_in_8bit=True` argument when calling the `from_pretrained()` method.
 
 ```python
 from optimum.intel import OVModelForCausalLM
@@ -114,18 +124,6 @@ model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
 
 > **NOTE:** `load_in_8bit` is enabled by default for the models larger than 1 billion parameters.
 
-There are also alternative compression options for a different performance-accuracy trade-off:
-
-| Option | Description |
-|---------------------------------------------------------------------|-------------------|
-| `fp16` | Float16 weights |
-| `int8` | INT8 weights |
-| `int4_sym_g128`, `int4_asym_g128`, `int4_sym_g64`, `int4_asym_g64`* | INT4 weights |
-
-*`sym` and `asym` stand for symmetric and asymmetric quantization, `g128` and `g64` means the group size `128` and `64` respectively.
-
-`--ratio` CLI parameter controls the ratio between 4-bit and 8-bit quantized layers and can also change performance-accuracy trade-off for the optimized model. It is valid only for INT4 quantization options.
-
 
 To apply quantization on both weights and activations, you can use the `OVQuantizer`, more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov#optimization).
 
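The documentation added above introduces `--weight-format int4` together with the `--sym`, `--group-size` and `--ratio` options (defined in `optimum/commands/export/openvino.py` further down in this commit). A hedged example combining them, reusing the `gpt2` model from the existing snippet; the particular values are illustrative, not recommendations from this commit:

```bash
# Illustrative INT4 export: symmetric quantization, group size 128,
# 80% of the layers in int4 and the remaining 20% kept in int8
optimum-cli export openvino --model gpt2 --weight-format int4 --sym --group-size 128 --ratio 0.8 ov_model
```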

docs/source/optimization_ov.mdx

+3 -3
@@ -26,11 +26,11 @@ Here is how to apply static quantization on a fine-tuned DistilBERT:
 
 ```python
 from functools import partial
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-from optimum.intel import OVConfig, OVQuantizer
+from transformers import AutoTokenizer
+from optimum.intel import OVConfig, OVQuantizer, OVModelForSequenceClassification
 
 model_id = "distilbert-base-uncased-finetuned-sst-2-english"
-model = AutoModelForSequenceClassification.from_pretrained(model_id)
+model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 # The directory where the quantized model will be saved
 save_dir = "ptq_model"
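For readers following the static-quantization walkthrough, a hedged sketch of how the snippet above typically continues once the imports are fixed; the calibration dataset, preprocessing function and keyword names are assumptions based on the surrounding optimum-intel documentation, not part of this commit:

```python
from functools import partial

from transformers import AutoTokenizer
from optimum.intel import OVConfig, OVQuantizer, OVModelForSequenceClassification

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
save_dir = "ptq_model"

def preprocess_fn(examples, tokenizer):
    # Tokenize the calibration samples; the "sentence" column is an assumption for the sst2 dataset
    return tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True)

quantizer = OVQuantizer.from_pretrained(model)
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=300,
    dataset_split="train",
)
# Apply static quantization and save the resulting OpenVINO model to save_dir
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory=save_dir)
```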

optimum/commands/export/openvino.py

+49 -4
@@ -77,7 +77,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
     optional_group.add_argument(
         "--weight-format",
         type=str,
-        choices=["fp32", "fp16", "int8", "int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"],
+        choices=["fp32", "fp16", "int8", "int4", "int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"],
         default=None,
         help=(
             "The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, int4_* - for INT4 compressed weights."
@@ -86,12 +86,24 @@ def parse_args_openvino(parser: "ArgumentParser"):
     optional_group.add_argument(
         "--ratio",
         type=float,
-        default=0.8,
+        default=None,
         help=(
             "Compression ratio between primary and backup precision. In the case of INT4, NNCF evaluates layer sensitivity and keeps the most impactful layers in INT8"
             "precision (by default 20%% in INT8). This helps to achieve better accuracy after weight compression."
         ),
     )
+    optional_group.add_argument(
+        "--sym",
+        action="store_true",
+        default=None,
+        help=("Whether to apply symmetric quantization"),
+    )
+    optional_group.add_argument(
+        "--group-size",
+        type=int,
+        default=None,
+        help=("The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization."),
+    )
     optional_group.add_argument(
         "--disable-stateful",
         action="store_true",
@@ -132,6 +144,7 @@ def parse_args(parser: "ArgumentParser"):
 
     def run(self):
         from ...exporters.openvino.__main__ import main_export
+        from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS, OVConfig
 
         if self.args.fp16:
             logger.warning(
@@ -144,6 +157,39 @@ def run(self):
             )
             self.args.weight_format = "int8"
 
+        weight_format = self.args.weight_format or "fp32"
+
+        ov_config = None
+        if weight_format in {"fp16", "fp32"}:
+            ov_config = OVConfig(dtype=weight_format)
+        else:
+            is_int8 = weight_format == "int8"
+
+            # For int4 quantization, if no parameter is provided, then use the default config if one exists
+            if (
+                not is_int8
+                and self.args.ratio is None
+                and self.args.group_size is None
+                and self.args.sym is None
+                and self.args.model in _DEFAULT_4BIT_CONFIGS
+            ):
+                quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
+            else:
+                quantization_config = {
+                    "bits": 8 if is_int8 else 4,
+                    "ratio": 1 if is_int8 else (self.args.ratio or 0.8),
+                    "sym": self.args.sym or False,
+                    "group_size": -1 if is_int8 else self.args.group_size,
+                }
+
+            if weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
+                logger.warning(
+                    f"--weight-format {weight_format} is deprecated, possible choices are fp32, fp16, int8, int4"
+                )
+                quantization_config["sym"] = "asym" not in weight_format
+                quantization_config["group_size"] = 128 if "128" in weight_format else 64
+            ov_config = OVConfig(quantization_config=quantization_config)
+
         # TODO : add input shapes
         main_export(
             model_name_or_path=self.args.model,
@@ -153,8 +199,7 @@ def run(self):
             cache_dir=self.args.cache_dir,
             trust_remote_code=self.args.trust_remote_code,
             pad_token_id=self.args.pad_token_id,
-            compression_option=self.args.weight_format,
-            compression_ratio=self.args.ratio,
+            ov_config=ov_config,
             stateful=not self.args.disable_stateful,
             convert_tokenizer=self.args.convert_tokenizer,
             # **input_shapes,
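The `run()` changes above keep the old composite INT4 values working but route them through a deprecation warning. A hedged sketch of roughly equivalent invocations (the model name is reused from the docs example; defaults such as `--ratio` follow the code above):

```bash
# Deprecated spelling: still accepted, but logs a deprecation warning
optimum-cli export openvino --model gpt2 --weight-format int4_sym_g128 ov_model

# Roughly equivalent with the flags introduced in this commit
optimum-cli export openvino --model gpt2 --weight-format int4 --sym --group-size 128 ov_model
```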

optimum/exporters/openvino/__init__.py

+14
@@ -1,3 +1,17 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
from .__main__ import main_export
216
from .convert import export, export_from_model, export_models, export_pytorch_via_onnx
317
from .stateful import ensure_stateful_is_available, patch_stateful

optimum/exporters/openvino/__main__.py

+38 -3
@@ -14,7 +14,7 @@
 
 import logging
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
@@ -41,6 +41,18 @@
 ]
 
 
+if TYPE_CHECKING:
+    from optimum.intel.openvino.configuration import OVConfig
+
+_COMPRESSION_OPTIONS = {
+    "int8": {"bits": 8},
+    "int4_sym_g128": {"bits": 4, "sym": True, "group_size": 128},
+    "int4_asym_g128": {"bits": 4, "sym": False, "group_size": 128},
+    "int4_sym_g64": {"bits": 4, "sym": True, "group_size": 64},
+    "int4_asym_g64": {"bits": 4, "sym": False, "group_size": 64},
+}
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -63,6 +75,7 @@ def main_export(
     fn_get_submodels: Optional[Callable] = None,
     compression_option: Optional[str] = None,
     compression_ratio: Optional[float] = None,
+    ov_config: "OVConfig" = None,
     stateful: bool = True,
     convert_tokenizer: bool = False,
     library_name: Optional[str] = None,
@@ -137,6 +150,29 @@ def main_export(
     >>> main_export("gpt2", output="gpt2_onnx/")
     ```
     """
+
+    if compression_option is not None:
+        logger.warning(
+            "The `compression_option` argument is deprecated and will be removed in optimum-intel v1.17.0. "
+            "Please, pass an `ov_config` argument instead `OVConfig(..., quantization_config=quantization_config)`."
+        )
+
+    if compression_ratio is not None:
+        logger.warning(
+            "The `compression_ratio` argument is deprecated and will be removed in optimum-intel v1.17.0. "
+            "Please, pass an `ov_config` argument instead `OVConfig(quantization_config={ratio=compression_ratio})`."
+        )
+
+    if ov_config is None and compression_option is not None:
+        from ...intel.openvino.configuration import OVConfig
+
+        if compression_option == "fp16":
+            ov_config = OVConfig(dtype="fp16")
+        elif compression_option != "fp32":
+            q_config = _COMPRESSION_OPTIONS[compression_option] if compression_option in _COMPRESSION_OPTIONS else {}
+            q_config["ratio"] = compression_ratio or 1.0
+            ov_config = OVConfig(quantization_config=q_config)
+
     original_task = task
     task = TasksManager.map_from_synonym(task)
     framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
@@ -293,8 +329,7 @@ class StoreAttr(object):
         model=model,
         output=output,
         task=task,
-        compression_option=compression_option,
-        compression_ratio=compression_ratio,
+        ov_config=ov_config,
         stateful=stateful,
         model_kwargs=model_kwargs,
         custom_onnx_configs=custom_onnx_configs,
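For API callers, the deprecation shim above points at the migration path away from `compression_option`/`compression_ratio`; a minimal sketch, assuming the `OVConfig` import used in the docs and the `quantization_config` dict keys built elsewhere in this commit:

```python
from optimum.exporters.openvino import main_export
from optimum.intel import OVConfig

# Before (deprecated, still works but logs a warning):
# main_export("gpt2", output="ov_model", compression_option="int4_sym_g128", compression_ratio=0.8)

# After: build the weight-compression settings into an OVConfig and pass it explicitly
ov_config = OVConfig(quantization_config={"bits": 4, "sym": True, "group_size": 128, "ratio": 0.8})
main_export("gpt2", output="ov_model", ov_config=ov_config)
```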
