[OV] Move data-driven quantization after model export for text-generation models #721

Merged · 29 commits · Jun 6, 2024

Commits (29)
56878bb
Add quantization with dataset after model export for text-generation …
nikita-savelyevv May 21, 2024
013a0f6
Tweak AWQ CLI interface
nikita-savelyevv May 21, 2024
c566ccc
Additional checks
nikita-savelyevv May 21, 2024
0a8fba0
Fix
nikita-savelyevv May 21, 2024
6dbb4fe
Trigger Build
nikita-savelyevv May 21, 2024
3722624
Add AWQ description
nikita-savelyevv May 22, 2024
dee582d
Add trust remote code argument
nikita-savelyevv May 22, 2024
a44c096
Black
nikita-savelyevv May 22, 2024
12dc672
Add note about possibility of skipping AWQ
nikita-savelyevv May 22, 2024
bcc4665
Removed saving to temporary directory; added core property handling f…
nikita-savelyevv May 23, 2024
40058da
Revert "Removed saving to temporary directory; added core property ha…
nikita-savelyevv May 23, 2024
0886f7e
Add saving intermediate weights in fp16; add removal of intermediate …
nikita-savelyevv May 23, 2024
ee9b1b7
Trigger checks
nikita-savelyevv May 23, 2024
cb57068
Trigger checks
nikita-savelyevv May 24, 2024
ee0b67f
Trigger checks
nikita-savelyevv May 28, 2024
cacbb36
Fix test
nikita-savelyevv May 31, 2024
814d96c
Refactor applying quantization with dataset
nikita-savelyevv May 31, 2024
d8017ab
Bring back quantization_config parameter
nikita-savelyevv May 31, 2024
24272dc
Trigger checks
nikita-savelyevv May 31, 2024
40b0e29
Apply comment
nikita-savelyevv Jun 3, 2024
f54aa40
Save tokenizer
nikita-savelyevv Jun 4, 2024
96bed29
Export CausalLM tokenizer
nikita-savelyevv Jun 4, 2024
a6005ad
Remove unneccessary if
nikita-savelyevv Jun 4, 2024
e311916
Remove extra variable
nikita-savelyevv Jun 4, 2024
fc44214
ruff
nikita-savelyevv Jun 4, 2024
709085b
Ruff 2
nikita-savelyevv Jun 4, 2024
a2084d9
Introduce a separate function to tokenizer conversion
nikita-savelyevv Jun 5, 2024
e8cc0e9
Black
nikita-savelyevv Jun 5, 2024
6815773
Merge branch 'main' into cli-awq
echarlaix Jun 6, 2024
87 changes: 63 additions & 24 deletions optimum/commands/export/openvino.py
@@ -19,9 +19,11 @@
from typing import TYPE_CHECKING, Optional

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from transformers.utils.quantization_config import QuantizationMethod

from ...exporters import TasksManager
from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from ..base import BaseOptimumCLICommand, CommandInfo


@@ -128,6 +130,33 @@ def parse_args_openvino(parser: "ArgumentParser"):
"compression is applied, they are compressed to INT8."
),
)
optional_group.add_argument(
"--awq",
action="store_true",
default=None,
help=(
"Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but requires "
"additional time for tuning weights on a calibration dataset. To run AWQ, please also provide a dataset "
"argument. Note: it's possible that there will be no matching patterns in the model to apply AWQ, in such "
"case it will be skipped."
),
)
optional_group.add_argument(
"--sensitivity-metric",
type=str,
default=None,
help=(
"The sensitivity metric for assigning quantization precision to layers. Can be one of the following: "
"['weight_quantization_error', 'hessian_input_activation', 'mean_activation_variance', "
"'max_activation_variance', 'mean_activation_magnitude']."
),
)
optional_group.add_argument(
"--num-samples",
type=int,
default=None,
help="The maximum number of samples to take from the dataset for quantization.",
)
optional_group.add_argument(
"--disable-stateful",
action="store_true",
@@ -180,7 +209,7 @@ def parse_args(parser: "ArgumentParser"):
return parse_args_openvino(parser)

def run(self):
from ...exporters.openvino.__main__ import main_export
from ...exporters.openvino.__main__ import infer_task, main_export, maybe_convert_tokenizers
from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS, OVConfig

if self.args.fp16:
@@ -208,6 +237,10 @@ def run(self):
and self.args.group_size is None
and self.args.sym is None
and self.args.all_layers is None
and self.args.dataset is None
and self.args.num_samples is None
and self.args.awq is None
and self.args.sensitivity_metric is None
and self.args.model in _DEFAULT_4BIT_CONFIGS
):
quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
@@ -218,6 +251,10 @@
"sym": self.args.sym or False,
"group_size": -1 if is_int8 else self.args.group_size,
"all_layers": None if is_int8 else self.args.all_layers,
"dataset": self.args.dataset,
"num_samples": self.args.num_samples,
"quant_method": QuantizationMethod.AWQ if self.args.awq else None,
"sensitivity_metric": self.args.sensitivity_metric,
}

if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
@@ -226,7 +263,6 @@
)
quantization_config["sym"] = "asym" not in self.args.weight_format
quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
quantization_config["dataset"] = self.args.dataset
ov_config = OVConfig(quantization_config=quantization_config)

library_name = TasksManager.infer_library_from_model(self.args.model, library_name=self.args.library)
@@ -240,12 +276,11 @@
if self.args.convert_tokenizer:
logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.")

if (
library_name == "diffusers"
and ov_config
and ov_config.quantization_config
and ov_config.quantization_config.dataset is not None
):
quantization_config = ov_config.quantization_config if ov_config else None
quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None
task = infer_task(self.args.task, self.args.model)

if library_name == "diffusers" and quantize_with_dataset:
if not is_diffusers_available():
raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models"))

@@ -270,25 +305,29 @@ def run(self):
else:
raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")

model = model_cls.from_pretrained(
self.args.model, export=True, quantization_config=ov_config.quantization_config
model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=quantization_config)
model.save_pretrained(self.args.output)
if not self.args.disable_convert_tokenizer:
maybe_convert_tokenizers(library_name, self.args.output, model)
elif task.startswith("text-generation") and quantize_with_dataset:
from optimum.intel import OVModelForCausalLM

# To quantize a text-generation model with a dataset, an instantiated OVModelForCausalLM is required
model = OVModelForCausalLM.from_pretrained(
self.args.model,
export=True,
quantization_config=quantization_config,
stateful=not self.args.disable_stateful,
trust_remote_code=self.args.trust_remote_code,
)
model.save_pretrained(self.args.output)

if self.args.disable_convert_tokenizer:
return

# avoid import when using other exporters (IPEX, INC)
from ...exporters.openvino.convert import export_tokenizer

output = Path(self.args.output)
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
export_tokenizer(tokenizer, output / "tokenizer")

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
export_tokenizer(tokenizer_2, output / "tokenizer_2")
maybe_save_preprocessors(self.args.model, self.args.output, trust_remote_code=self.args.trust_remote_code)
if not self.args.disable_convert_tokenizer:
preprocessors = maybe_load_preprocessors(
self.args.model, trust_remote_code=self.args.trust_remote_code
)
maybe_convert_tokenizers(library_name, self.args.output, preprocessors=preprocessors)
else:
# TODO : add input shapes
main_export(
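
Note: the new --awq, --sensitivity-metric and --num-samples options (together with the existing --dataset) are folded into the same quantization_config dict that can be passed programmatically. The sketch below is illustrative only; the model id and the specific option values are placeholders and are not prescribed by this diff.

# Illustrative sketch, not part of the diff: roughly the config the CLI assembles for
# `--weight-format int4 --awq --dataset wikitext2 --num-samples 100
# --sensitivity-metric max_activation_variance`.
from transformers.utils.quantization_config import QuantizationMethod
from optimum.intel import OVModelForCausalLM

quantization_config = {
    "bits": 4,
    "ratio": 1.0,                 # placeholder value
    "sym": True,                  # placeholder value
    "group_size": 64,             # placeholder value
    "all_layers": None,
    "dataset": "wikitext2",
    "num_samples": 100,
    "quant_method": QuantizationMethod.AWQ,
    "sensitivity_metric": "max_activation_variance",
}

# Data-driven 4-bit compression is now applied to an already exported model instance:
model = OVModelForCausalLM.from_pretrained(
    "example-org/example-llm",    # placeholder model id
    export=True,
    quantization_config=quantization_config,
    stateful=True,
)
model.save_pretrained("exported_int4_awq")
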
86 changes: 50 additions & 36 deletions optimum/exporters/openvino/__main__.py
@@ -44,6 +44,22 @@
logger = logging.getLogger(__name__)


def infer_task(task, model_name_or_path):
task = TasksManager.map_from_synonym(task)
if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
except KeyError as e:
raise KeyError(
f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
except RequestsConnectionError as e:
raise RequestsConnectionError(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
return task


def main_export(
model_name_or_path: str,
output: Union[str, Path],
@@ -174,7 +190,7 @@ def main_export(
ov_config = OVConfig(quantization_config=q_config)

original_task = task
task = TasksManager.map_from_synonym(task)
task = infer_task(task, model_name_or_path)
framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
library_name_is_not_provided = library_name is None
library_name = TasksManager.infer_library_from_model(
@@ -188,18 +204,6 @@
)
library_name = "transformers"

if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
except KeyError as e:
raise KeyError(
f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
except RequestsConnectionError as e:
raise RequestsConnectionError(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

do_gptq_patching = False
custom_architecture = False
loading_kwargs = {}
@@ -360,36 +364,46 @@ class StoreAttr(object):
**kwargs_shapes,
)

# hide openvino import when using other exporters
from optimum.exporters.openvino.convert import export_tokenizer
if convert_tokenizer:
maybe_convert_tokenizers(library_name, output, model, preprocessors)

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
GPTQQuantizer.post_init_model = orig_post_init_model

if convert_tokenizer and is_openvino_tokenizers_available():
if library_name != "diffusers":
tokenizer = next(
(preprocessor for preprocessor in preprocessors if isinstance(preprocessor, PreTrainedTokenizerBase)),
None,
)

if tokenizer is not None:
def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None):
"""
Tries to convert tokenizers to OV format and export them to disk.

Arguments:
library_name (`str`):
The library name.
output (`Path`):
Path to save converted tokenizers to.
model (`PreTrainedModel`, *optional*, defaults to None):
Model instance.
preprocessors (`Iterable`, *optional*, defaults to None):
Iterable possibly containing tokenizers to be converted.
"""
from optimum.exporters.openvino.convert import export_tokenizer

if is_openvino_tokenizers_available():
if library_name != "diffusers" and preprocessors:
tokenizer = next(filter(lambda it: isinstance(it, PreTrainedTokenizerBase), preprocessors), None)
if tokenizer:
try:
export_tokenizer(tokenizer, output)
except Exception as exception:
logger.warning(
"Could not load tokenizer using specified model ID or path. OpenVINO tokenizer/detokenizer "
f"models won't be generated. Exception: {exception}"
)
else:
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
export_tokenizer(tokenizer, output / "tokenizer")

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
export_tokenizer(tokenizer_2, output / "tokenizer_2")
elif convert_tokenizer and not is_openvino_tokenizers_available():
elif model:
for tokenizer_name in ("tokenizer", "tokenizer_2"):
tokenizer = getattr(model, tokenizer_name, None)
if tokenizer:
export_tokenizer(tokenizer, output / tokenizer_name)
else:
logger.warning("Tokenizer won't be converted.")

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
GPTQQuantizer.post_init_model = orig_post_init_model
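
A minimal usage sketch for the two helpers factored out above (infer_task and maybe_convert_tokenizers); the model id and output directory below are placeholders, not taken from this PR.

# Illustrative sketch, not part of the diff: calling the refactored helpers directly.
from pathlib import Path

from optimum.exporters.openvino.__main__ import infer_task, maybe_convert_tokenizers
from optimum.utils.save_utils import maybe_load_preprocessors

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder model id
task = infer_task("auto", model_id)                # resolves "auto" to the model's task

preprocessors = maybe_load_preprocessors(model_id, trust_remote_code=False)
# Converts any tokenizer found among the preprocessors to OpenVINO format
# (when openvino-tokenizers is installed), otherwise logs a warning.
maybe_convert_tokenizers("transformers", Path("ov_output"), preprocessors=preprocessors)
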
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_base.py
@@ -132,6 +132,7 @@ def fix_op_names_duplicates(model: openvino.runtime.Model):
if file_name.suffix == ".onnx":
model = fix_op_names_duplicates(model) # should be called during model conversion to IR

# TODO: remove this way of applying quantization; instead apply it after instance of OVModel* is loaded
if quantization_config:
if not is_nncf_available():
raise ImportError(
20 changes: 7 additions & 13 deletions optimum/intel/openvino/modeling_decoder.py
@@ -741,17 +741,7 @@ def _from_pretrained(
local_files_only=local_files_only,
)

if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)

quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

load_in_4bit = quantization_config.bits == 4 if quantization_config else False

model = cls.load_model(
model_cache_path,
quantization_config=None if load_in_4bit else quantization_config,
)
model = cls.load_model(model_cache_path)

model_type = config.model_type.replace("_", "-")
if model_type == "bloom":
@@ -761,7 +751,11 @@
else:
init_cls = cls

enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

enable_compilation = kwargs.pop("compile", True) and not quantization_config
causal_model = init_cls(
model=model,
config=config,
@@ -771,7 +765,7 @@
**kwargs,
)

if load_in_4bit:
if quantization_config:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights requires nncf, please install it with `pip install nncf`"
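
The net effect of this change is that load_model() no longer performs weight compression: an OVModelForCausalLM is instantiated first, NNCF-based compression runs on it afterwards, and compilation is deferred whenever a quantization_config is present. A hedged illustration of the caller-facing behaviour (the model id is a placeholder):

# Illustrative sketch, not part of the diff: loading an OpenVINO model with the
# {"bits": 4} shortcut, which falls back to a per-model default from
# _DEFAULT_4BIT_CONFIGS when one is registered for this model id.
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained(
    "example-org/example-llm-openvino",  # placeholder: an already exported OV model
    quantization_config={"bits": 4},
)
# Weights are compressed after the OVModelForCausalLM wrapper is created; the model
# is compiled afterwards (explicitly via model.compile() or on the first forward call).
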
12 changes: 11 additions & 1 deletion tests/openvino/test_exporters_cli.py
@@ -89,6 +89,14 @@ class OVCLIExportTestCase(unittest.TestCase):
("text-generation-with-past", "opt125m", "int4_sym_g64", 62, 86),
("text-generation-with-past", "opt125m", "int4_asym_g64", 62, 86),
("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 32),
(
"text-generation-with-past",
"llama_awq",
"int4 --ratio 1.0 --sym --group-size 16 --awq --dataset wikitext2 --num-samples 100 "
"--sensitivity-metric max_activation_variance",
4,
28,
),
]

def _openvino_export(
@@ -197,17 +205,19 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in
@parameterized.expand(TEST_4BIT_CONFIGURATONS)
def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_int8: int, expected_int4: int):
with TemporaryDirectory() as tmpdir:
subprocess.run(
result = subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}",
shell=True,
check=True,
capture_output=True,
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

_, num_int8, num_int4 = get_num_quantized_nodes(model)
self.assertEqual(expected_int8, num_int8)
self.assertEqual(expected_int4, num_int4)
self.assertTrue("--awq" not in option or b"Applying AWQ" in result.stdout)

def test_exporters_cli_help(self):
subprocess.run(