
Commit 6c8fa79

Add bits and sym parameters to the OV quantization config (#560)
* Add bits and sym parameters to the OV quantization config
* format
* add nncf version
* Fix config saving
* add ov config test
* remove load_in_4bit argument
* add weight only quant for int8
* fix style
* add nncf check
* remove _int4_weight_only_quantization
* fix typo
* fix style
1 parent 7e1a21e commit 6c8fa79

9 files changed, +238 -232 lines changed


optimum/exporters/openvino/convert.py (+2)

@@ -500,6 +500,8 @@ def export_models(
         Returns:
             list of input_names and output_names from ONNX configuration
     """
+
+    # TODO : modify compression_option to quantization_config
     outputs = []

     if output_names is not None and len(output_names) != len(models_and_onnx_configs):

optimum/intel/openvino/__init__.py (+1 -2)

@@ -36,11 +36,10 @@

 patch_torch_operators()

-from .configuration import OVConfig
+from .configuration import OVConfig, OVWeightQuantizationConfig
 from .quantization import OVQuantizer
 from .trainer import OVTrainer
 from .training_args import OVTrainingArguments
-from .weight_quantization import OVWeightQuantizationConfig

 from .modeling import (
     OVModelForAudioClassification,
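
With this commit, OVWeightQuantizationConfig is defined in configuration.py instead of the removed weight_quantization module, so the public import path shown in this __init__.py stays the same; a minimal import sketch (assuming optimum-intel is installed with its OpenVINO dependencies):

# Both classes are re-exported from the subpackage, as in the diff above.
from optimum.intel.openvino import OVConfig, OVWeightQuantizationConfig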

optimum/intel/openvino/configuration.py (+113 -3)

@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union

 import torch
+from transformers import PretrainedConfig
 from transformers.utils.quantization_config import QuantizationConfigMixin

 from optimum.configuration_utils import BaseConfig

-from .weight_quantization import OVWeightQuantizationConfig
-

 DEFAULT_QUANTIZATION_CONFIG = {
     "algorithm": "quantization",
@@ -77,6 +77,28 @@
 }


+DEFAULT_4BIT_CONFIGS = {
+    "databricks/dolly-v2-3b": {"bits": 4, "sym": False, "group_size": 32, "ratio": 0.5},
+    "EleutherAI/gpt-j-6b": {"bits": 4, "sym": False, "group_size": 64},
+    "facebook/opt-6.7b": {"bits": 4, "sym": False, "group_size": 64, "ratio": 0.8},
+    "bigscience/bloomz-7b1": {"bits": 4, "sym": False, "group_size": 32, "ratio": 0.6},
+    "togethercomputer/RedPajama-INCITE-7B-Instruct": {"bits": 4, "sym": False, "group_size": 128},
+    "HuggingFaceH4/zephyr-7b-beta": {"bits": 4, "sym": True, "group_size": 64, "ratio": 0.6},
+    "meta-llama/Llama-2-7b": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.6},
+    "meta-llama/Llama-2-7b-chat": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.8},
+    "meta-llama/Llama-2-13b-chat": {"bits": 4, "sym": True, "group_size": 64, "ratio": 0.8},
+    "stabilityai/stablelm-3b-4e1t": {"bits": 4, "sym": True, "group_size": 64, "ratio": 0.8},
+    "stablelm-epoch-3b-preview": {"bits": 4, "sym": True, "group_size": 64, "ratio": 0.8},
+    "stable-zephyr-3b-dpo": {"bits": 4, "sym": False, "group_size": 64, "ratio": 0.8},
+    "pansophic/rocket-3B": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.8},
+    "THUDM/chatglm2-6b": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.72},
+    "Qwen/Qwen-7B-Chat": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.6},
+    "openlm-research/open_llama_3b": {"bits": 4, "sym": True, "group_size": 64, "all_layers": True},
+    "tiiuae/falcon-7b": {"bits": 4, "sym": True, "group_size": 64, "all_layers": True},
+    "psmathur/orca_mini_3b": {"bits": 4, "sym": True, "group_size": 64, "all_layers": True},
+}
+
+
 class OVConfig(BaseConfig):
     CONFIG_NAME = "openvino_config.json"
     FULL_CONFIGURATION_FILE = "openvino_config.json"
@@ -127,3 +149,91 @@ def _enable_standard_onnx_export_option(self):
         for i, algo_config in enumerate(self.compression):
             if algo_config["algorithm"] == "quantization":
                 self.compression[i]["export_to_onnx_standard_ops"] = self.save_onnx_model
+
+
+@dataclass
+class OVWeightQuantizationConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class for all the attributes and features used to quantize, with NNCF, the weights of a model
+    loaded through the `optimum-intel` API.
+
+    Args:
+
+        bits (`int`, defaults to 8):
+            The number of bits to quantize to.
+        sym (`bool`, *optional*, defaults to `False`):
+            Whether to use symmetric quantization.
+        tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
+            The tokenizer used to process the dataset. You can pass either:
+                - A custom tokenizer object.
+                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+                    Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
+                    user or organization name, like `dbmdz/bert-base-german-cased`.
+                - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
+                    using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+        dataset (`Union[List[str]]`, *optional*):
+            The dataset used for data-aware compression. You can provide your own dataset as a list of strings or use
+            one of ['wikitext2', 'c4', 'c4-new', 'ptb', 'ptb-new'].
+        group_size (`int`, *optional*, defaults to 128):
+            The group size to use for quantization. The recommended value is 128, and -1 uses per-column quantization.
+        ratio (`float`, *optional*, defaults to 1.0):
+            The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers are quantized to INT4_ASYM
+            and the rest to INT8_ASYM).
+        all_layers (`bool`, *optional*):
+            Whether to compress all layers to 4-bit rather than keeping a subset of them in 8-bit precision.
+        sensitivity_metric (`nncf.SensitivityMetric`, *optional*):
+            The sensitivity metric for assigning quantization precision to layers. In order to preserve the accuracy
+            of the model, the more sensitive layers receive a higher precision.
+        awq (`bool`, *optional*):
+            Enables the AWQ method to unify weight ranges and improve overall model accuracy.
+        ignored_scope (`nncf.IgnoredScope`, *optional*):
+            An ignored scope that defines the list of model control flow graph nodes to be ignored during quantization.
+    """
+
+    def __init__(
+        self,
+        bits: int = 8,
+        sym: bool = False,
+        tokenizer: Any = None,
+        dataset: Optional[str] = None,
+        ratio: Optional[float] = None,
+        group_size: Optional[int] = None,
+        all_layers: Optional[bool] = None,
+        sensitivity_metric: Optional[str] = None,
+        ignored_scope: Optional[dict] = None,
+        **kwargs,
+    ):
+        self.bits = bits
+        self.sym = sym
+        self.tokenizer = tokenizer
+        self.dataset = dataset
+        self.group_size = group_size
+        self.ratio = ratio
+        self.all_layers = all_layers
+        self.sensitivity_metric = sensitivity_metric
+        self.ignored_scope = ignored_scope
+        self.quant_method = "default"  # TODO : enable AWQ after nncf v2.9.0 release
+        self.post_init()
+
+    def post_init(self):
+        r"""
+        Safety checker that arguments are correct.
+        """
+        if self.ratio is not None and not (0 <= self.ratio <= 1):
+            raise ValueError("`ratio` must be between 0 and 1.")
+        if self.group_size is not None and self.group_size != -1 and self.group_size <= 0:
+            raise ValueError("`group_size` must be greater than 0 or equal to -1")
+        if self.dataset is not None and isinstance(self.dataset, str):
+            if self.dataset not in ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]:
+                raise ValueError(
+                    f"""You have entered a string value for dataset. You can only choose between
+                    ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
+                )
+
+        if self.bits not in [4, 8]:
+            raise ValueError(f"Only 4-bit and 8-bit quantization is supported, but found bits={self.bits}")
+
+
+def _check_default_4bit_configs(config: PretrainedConfig):
+    return DEFAULT_4BIT_CONFIGS.get(config.name_or_path, None)
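
To illustrate the new parameters, here is a short, hypothetical usage sketch (not part of the diff; it assumes optimum-intel is installed): the config is built with the new bits and sym arguments, and post_init() rejects unsupported values at construction time.

from optimum.intel.openvino import OVWeightQuantizationConfig

# 4-bit symmetric weight-only quantization, 80% of layers in INT4, per-group scales
wq_config = OVWeightQuantizationConfig(bits=4, sym=True, group_size=128, ratio=0.8)

# post_init() runs inside __init__ and validates the arguments:
try:
    OVWeightQuantizationConfig(bits=3)  # only 4 and 8 bits are accepted
except ValueError as err:
    print(err)

try:
    OVWeightQuantizationConfig(bits=4, ratio=1.5)  # ratio must lie in [0, 1]
except ValueError as err:
    print(err)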

optimum/intel/openvino/modeling_base.py (-6)

@@ -155,7 +155,6 @@ def _from_pretrained(
         from_onnx: bool = False,
         local_files_only: bool = False,
         load_in_8bit: bool = False,
-        load_in_4bit: bool = False,
         **kwargs,
     ):
         """
@@ -185,11 +184,7 @@
                 Whether or not to only look at local files (i.e., do not try to download the model).
             load_in_8bit (`bool`, *optional*, defaults to `False`):
                 Whether or not to apply 8-bit weight quantization.
-            load_in_4bit (`bool`, *optional*, defaults to `False`):
-                Whether or not to apply 4-bit weight quantization.
         """
-        if load_in_4bit:
-            raise ValueError("load_in_4bit is available for OVModelForCausalLM only.")
         model_path = Path(model_id)
         default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME
         file_name = file_name or default_file_name
@@ -257,7 +252,6 @@ def _from_transformers(
         task: Optional[str] = None,
         trust_remote_code: bool = False,
         load_in_8bit: Optional[bool] = None,
-        load_in_4bit: Optional[bool] = None,
         **kwargs,
     ):
         """

optimum/intel/openvino/modeling_decoder.py (+22 -8)

@@ -32,10 +32,11 @@

 from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
 from ...exporters.openvino.stateful import model_has_state
+from ..utils.import_utils import is_nncf_available
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
+from .configuration import OVWeightQuantizationConfig, _check_default_4bit_configs
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
-from .weight_quantization import OVWeightQuantizationConfig, compress_decoder_weights


 logger = logging.getLogger(__name__)
@@ -238,7 +239,6 @@ def _from_transformers(
         use_cache: bool = True,
         trust_remote_code: bool = False,
         load_in_8bit: Optional[bool] = None,
-        load_in_4bit: Optional[bool] = None,
         quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
         **kwargs,
     ):
@@ -258,8 +258,9 @@

         # If load_in_8bit is not specified then compression_option should be set to None and will be set by default in main_export depending on the model size
         compression_option = None
-        if load_in_8bit is not None or load_in_4bit is not None:
+        if load_in_8bit is not None or quantization_config is not None:
             compression_option = "fp32"
+
         stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
         main_export(
             model_name_or_path=model_id,
@@ -285,7 +286,6 @@
             use_cache=use_cache,
             load_in_8bit=load_in_8bit,
             stateful=None,
-            load_in_4bit=load_in_4bit,
             quantization_config=quantization_config,
             **kwargs,
         )
@@ -556,7 +556,6 @@ def _from_pretrained(
         from_onnx: bool = False,
         local_files_only: bool = False,
         load_in_8bit: bool = False,
-        load_in_4bit: bool = False,
         quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
         **kwargs,
     ):
@@ -575,8 +574,10 @@
             local_files_only=local_files_only,
         )

-        if load_in_8bit and load_in_4bit:
-            raise ValueError("Either load_in_8bit or load_in_4bit should be set to True.")
+        if isinstance(quantization_config, dict):
+            quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
+
+        load_in_4bit = quantization_config.bits == 4 if quantization_config else False
         model = cls.load_model(model_cache_path, load_in_8bit=False if load_in_4bit else load_in_8bit)

         model_type = config.model_type.replace("_", "-")
@@ -594,7 +595,20 @@
         causal_model = init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)

         if load_in_4bit:
-            compress_decoder_weights(causal_model, quantization_config)
+            if not is_nncf_available():
+                raise ImportError(
+                    "Quantization of the weights requires nncf, please install it with `pip install nncf`"
+                )
+            from .quantization import _weight_only_quantization
+
+            default_config = _check_default_4bit_configs(config)
+
+            if default_config:
+                logger.info(
+                    f"For the given model, we recommend the following `quantization_config` : {default_config}"
+                )
+
+            _weight_only_quantization(causal_model, quantization_config)
         return causal_model
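
End to end, 4-bit weight-only quantization is now driven by quantization_config instead of the removed load_in_4bit flag. A minimal sketch of the new call (hypothetical usage consistent with this diff; it requires nncf, and the checkpoint name is only an example taken from DEFAULT_4BIT_CONFIGS):

from optimum.intel import OVModelForCausalLM
from optimum.intel.openvino import OVWeightQuantizationConfig

# Pass the config object directly...
wq_config = OVWeightQuantizationConfig(bits=4, sym=True, group_size=128, ratio=0.8)
model = OVModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat", export=True, quantization_config=wq_config
)

# ...or as a plain dict: _from_pretrained converts it with
# OVWeightQuantizationConfig.from_dict and triggers _weight_only_quantization
# whenever quantization_config.bits == 4.
model = OVModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat",
    export=True,
    quantization_config={"bits": 4, "sym": True, "group_size": 128, "ratio": 0.8},
)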
