[OV] Move data-driven quantization after model export for text-generation models #721
@@ -14,7 +14,9 @@
 """Defines the command line for the export with OpenVINO."""

 import logging
+import shutil
 import sys
+import tempfile
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional

@@ -128,6 +130,33 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "compression is applied, they are compressed to INT8."
         ),
     )
+    optional_group.add_argument(
+        "--awq",
+        action="store_true",
+        default=None,
+        help=(
+            "Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but requires "
+            "additional time for tuning weights on a calibration dataset. To run AWQ, please also provide a dataset "
+            "argument. Note: it's possible that there will be no matching patterns in the model to apply AWQ, in such "
+            "case it will be skipped."
+        ),
+    )
+    optional_group.add_argument(
+        "--sensitivity-metric",
+        type=str,
+        default=None,
+        help=(
+            "The sensitivity metric for assigning quantization precision to layers. Can be one of the following: "
+            "['weight_quantization_error', 'hessian_input_activation', 'mean_activation_variance', "
+            "'max_activation_variance', 'mean_activation_magnitude']."
+        ),
+    )
+    optional_group.add_argument(
+        "--num-samples",
+        type=int,
+        default=None,
+        help="The maximum number of samples to take from the dataset for quantization.",
+    )
     optional_group.add_argument(
         "--disable-stateful",
         action="store_true",
@@ -180,7 +209,7 @@ def parse_args(parser: "ArgumentParser"):
         return parse_args_openvino(parser)

     def run(self):
-        from ...exporters.openvino.__main__ import main_export
+        from ...exporters.openvino.__main__ import infer_task, main_export
         from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS, OVConfig

         if self.args.fp16:
@@ -208,6 +237,10 @@ def run(self):
             and self.args.group_size is None
             and self.args.sym is None
             and self.args.all_layers is None
+            and self.args.dataset is None
+            and self.args.num_samples is None
+            and self.args.awq is None
+            and self.args.sensitivity_metric is None
             and self.args.model in _DEFAULT_4BIT_CONFIGS
         ):
             quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
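For readers unfamiliar with _DEFAULT_4BIT_CONFIGS: it is a lookup table in optimum/intel/openvino/configuration.py that maps known model ids to curated 4-bit compression settings. The sketch below is illustrative only (the model id and values are placeholders, not real entries); the point is that the fallback only triggers when every user-facing quantization argument is still None:

# Illustrative sketch -- the real table lives in optimum/intel/openvino/configuration.py
# and this entry is made up.
_DEFAULT_4BIT_CONFIGS = {
    "example-org/example-llm": {"bits": 4, "sym": False, "group_size": 128},
}

def default_config_for(args):
    user_args = (args.group_size, args.sym, args.all_layers, args.dataset,
                 args.num_samples, args.awq, args.sensitivity_metric)
    if all(value is None for value in user_args):
        # Only when nothing was requested explicitly do we consult the table.
        return _DEFAULT_4BIT_CONFIGS.get(args.model)
    return None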
@@ -218,6 +251,10 @@ def run(self):
                 "sym": self.args.sym or False,
                 "group_size": -1 if is_int8 else self.args.group_size,
                 "all_layers": None if is_int8 else self.args.all_layers,
+                "dataset": self.args.dataset,
+                "num_samples": self.args.num_samples,
+                "quant_method": "awq" if self.args.awq else None,
+                "sensitivity_metric": self.args.sensitivity_metric,
             }

         if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
@@ -226,7 +263,6 @@ def run(self):
             )
             quantization_config["sym"] = "asym" not in self.args.weight_format
             quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
-            quantization_config["dataset"] = self.args.dataset
         ov_config = OVConfig(quantization_config=quantization_config)

         library_name = TasksManager.infer_library_from_model(self.args.model, library_name=self.args.library)
@@ -290,6 +326,22 @@ def run(self):
                 if tokenizer_2 is not None:
                     export_tokenizer(tokenizer_2, output / "tokenizer_2")
         else:
+            task = infer_task(self.args.task, self.args.model)
+            quantization_config = ov_config.quantization_config if ov_config else None
+            quantize_after_export = (
+                task.startswith("text-generation")
+                and quantization_config
+                and hasattr(quantization_config, "dataset")
+                and quantization_config.dataset is not None
+            )
+            if quantize_after_export:
+                # In order to quantize a text-generation model with a dataset, an instance of OVModelForCausalLM is
+                # required. That's why the quantization is skipped during export and applied explicitly after export.
+                ov_config.quantization_config = None
+                # Export intermediate model with f16 weights to save up disk space
+                original_dtype_value = ov_config.dtype
+                ov_config.dtype = "fp16"
+
             # TODO : add input shapes
             main_export(
                 model_name_or_path=self.args.model,
@@ -305,3 +357,25 @@ def run(self):
                 library_name=library_name,
                 # **input_shapes,
             )
+
+            if quantize_after_export:
+                try:
+                    from optimum.intel import OVModelForCausalLM, OVQuantizer
+
+                    ov_config.dtype = original_dtype_value
+                    model = OVModelForCausalLM.from_pretrained(
+                        self.args.output, trust_remote_code=self.args.trust_remote_code
+                    )
+                    quantizer = OVQuantizer(model)
+                    quantization_config.tokenizer = quantization_config.tokenizer or str(self.args.output)
+                    # TODO: set save_directory=self.args.output once OV is updated to 2024.3
+                    quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
+                    with tempfile.TemporaryDirectory() as temp_dir:
+                        model.save_pretrained(temp_dir)
+                        ov_config.save_pretrained(self.args.output)
+                        shutil.copy(f"{temp_dir}/openvino_model.xml", f"{self.args.output}/openvino_model.xml")
+                        shutil.copy(f"{temp_dir}/openvino_model.bin", f"{self.args.output}/openvino_model.bin")
+                except Exception as e:
+                    # Delete non-compressed model if compression failed for some reason
+                    shutil.rmtree(str(self.args.output))
+                    raise e
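Taken together, the new code path is roughly equivalent to the following hedged sketch against the public optimum-intel API. The paths, the wikitext2 dataset name, and the specific metric value are placeholders; the OVWeightQuantizationConfig field names mirror the dict built earlier in this diff:

from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer, OVWeightQuantizationConfig

# Step 1: assume the model was already exported with fp16 weights into export_dir.
export_dir = "path/to/exported/model"  # placeholder

# Step 2: rebuild the data-driven 4-bit config that the export step skipped.
quantization_config = OVWeightQuantizationConfig(
    bits=4,
    dataset="wikitext2",                           # calibration data, from --dataset
    num_samples=128,                               # from --num-samples
    quant_method="awq",                            # from --awq
    sensitivity_metric="max_activation_variance",  # from --sensitivity-metric
    tokenizer=export_dir,                          # falls back to the export directory, as in the diff
)

# Step 3: load as OVModelForCausalLM and apply quantization explicitly.
model = OVModelForCausalLM.from_pretrained(export_dir)
quantizer = OVQuantizer(model)
quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
model.save_pretrained("path/to/quantized/model")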
Review discussion:

Why not exporting + applying quantization using …

If possible, it could actually make sense to do this for all models (as is already the case for SD models).

Thanks for the suggestion! This is indeed better.

It would be more convenient from the code-maintenance side. But compared to calling …
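The truncated suggestion above presumably points at the one-step path optimum-intel already exposes, where export and weight compression happen inside a single from_pretrained call. A hedged sketch of that alternative (the model id, dataset, and output path are placeholders):

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# export=True converts the checkpoint to OpenVINO on the fly and applies the
# weight compression in the same call, with no intermediate fp16 artifact.
model = OVModelForCausalLM.from_pretrained(
    "example-org/example-llm",  # placeholder model id
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=4, dataset="wikitext2"),
)
model.save_pretrained("path/to/quantized/model")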
Not sure if it should be QuantizationMethod.AWQ instead of "awq", or if the configuration takes care of this (referencing optimum/intel/openvino/quantization.py, line 822 at 52875b9).

I hesitated to do it this way because it would require introducing a dependency on transformers in this file in order to import QuantizationMethod. But now I see that transformers is a general requirement of optimum, so it should be fine I guess.
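A small sketch of the point discussed above, assuming transformers' QuantizationMethod (a str-based Enum defined in transformers/utils/quantization_config.py with AWQ = "awq"): because the enum subclasses str, the literal "awq" compares equal to the enum member, so either spelling works at comparison sites, and the enum is simply the more explicit choice:

from transformers.utils.quantization_config import QuantizationMethod

print(QuantizationMethod.AWQ == "awq")  # True: str-enum members compare equal to their values
quant_method = QuantizationMethod.AWQ   # the explicit alternative to the plain string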