
Commit 6888c0a

[OV] Move data-driven quantization after model export for text-generation models (#721)
* Add quantization with dataset after model export for text-generation models
* Tweak AWQ CLI interface
* Additional checks
* Fix
* Trigger Build
* Add AWQ description
* Add trust remote code argument
* Black
* Add note about possibility of skipping AWQ
* Removed saving to temporary directory; added core property handling for OVModelForCausalLM
* Revert "Removed saving to temporary directory; added core property handling for OVModelForCausalLM"

  This reverts commit bcc4665.

* Add saving intermediate weights in fp16; add removal of intermediate model if compression fails
* Trigger checks
* Trigger checks
* Trigger checks
* Fix test
* Refactor applying quantization with dataset
* Bring back quantization_config parameter
* Trigger checks
* Apply comment
* Save tokenizer
* Export CausalLM tokenizer
* Remove unneccessary if
* Remove extra variable
* ruff
* Ruff 2
* Introduce a separate function to tokenizer conversion
* Black
1 parent f06f504 commit 6888c0a

File tree

5 files changed: +134 −74 lines changed

optimum/commands/export/openvino.py
optimum/exporters/openvino/__main__.py
optimum/intel/openvino/modeling_base.py
optimum/intel/openvino/modeling_decoder.py
tests/openvino/test_exporters_cli.py

optimum/commands/export/openvino.py

+63-24
@@ -19,9 +19,11 @@
 from typing import TYPE_CHECKING, Optional

 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from transformers.utils.quantization_config import QuantizationMethod

 from ...exporters import TasksManager
 from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
+from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 from ..base import BaseOptimumCLICommand, CommandInfo


@@ -128,6 +130,33 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "compression is applied, they are compressed to INT8."
         ),
     )
+    optional_group.add_argument(
+        "--awq",
+        action="store_true",
+        default=None,
+        help=(
+            "Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but requires "
+            "additional time for tuning weights on a calibration dataset. To run AWQ, please also provide a dataset "
+            "argument. Note: it's possible that there will be no matching patterns in the model to apply AWQ, in such "
+            "case it will be skipped."
+        ),
+    )
+    optional_group.add_argument(
+        "--sensitivity-metric",
+        type=str,
+        default=None,
+        help=(
+            "The sensitivity metric for assigning quantization precision to layers. Can be one of the following: "
+            "['weight_quantization_error', 'hessian_input_activation', 'mean_activation_variance', "
+            "'max_activation_variance', 'mean_activation_magnitude']."
+        ),
+    )
+    optional_group.add_argument(
+        "--num-samples",
+        type=int,
+        default=None,
+        help="The maximum number of samples to take from the dataset for quantization.",
+    )
     optional_group.add_argument(
         "--disable-stateful",
         action="store_true",
@@ -180,7 +209,7 @@ def parse_args(parser: "ArgumentParser"):
         return parse_args_openvino(parser)

     def run(self):
-        from ...exporters.openvino.__main__ import main_export
+        from ...exporters.openvino.__main__ import infer_task, main_export, maybe_convert_tokenizers
         from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS, OVConfig

         if self.args.fp16:
@@ -208,6 +237,10 @@ def run(self):
                 and self.args.group_size is None
                 and self.args.sym is None
                 and self.args.all_layers is None
+                and self.args.dataset is None
+                and self.args.num_samples is None
+                and self.args.awq is None
+                and self.args.sensitivity_metric is None
                 and self.args.model in _DEFAULT_4BIT_CONFIGS
             ):
                 quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
@@ -218,6 +251,10 @@ def run(self):
                 "sym": self.args.sym or False,
                 "group_size": -1 if is_int8 else self.args.group_size,
                 "all_layers": None if is_int8 else self.args.all_layers,
+                "dataset": self.args.dataset,
+                "num_samples": self.args.num_samples,
+                "quant_method": QuantizationMethod.AWQ if self.args.awq else None,
+                "sensitivity_metric": self.args.sensitivity_metric,
             }

             if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
@@ -226,7 +263,6 @@ def run(self):
                 )
                 quantization_config["sym"] = "asym" not in self.args.weight_format
                 quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
-                quantization_config["dataset"] = self.args.dataset
             ov_config = OVConfig(quantization_config=quantization_config)

         library_name = TasksManager.infer_library_from_model(self.args.model, library_name=self.args.library)
@@ -240,12 +276,11 @@ def run(self):
         if self.args.convert_tokenizer:
             logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.")

-        if (
-            library_name == "diffusers"
-            and ov_config
-            and ov_config.quantization_config
-            and ov_config.quantization_config.dataset is not None
-        ):
+        quantization_config = ov_config.quantization_config if ov_config else None
+        quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None
+        task = infer_task(self.args.task, self.args.model)
+
+        if library_name == "diffusers" and quantize_with_dataset:
             if not is_diffusers_available():
                 raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models"))

@@ -270,25 +305,29 @@ def run(self):
             else:
                 raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")

-            model = model_cls.from_pretrained(
-                self.args.model, export=True, quantization_config=ov_config.quantization_config
+            model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=quantization_config)
+            model.save_pretrained(self.args.output)
+            if not self.args.disable_convert_tokenizer:
+                maybe_convert_tokenizers(library_name, self.args.output, model)
+        elif task.startswith("text-generation") and quantize_with_dataset:
+            from optimum.intel import OVModelForCausalLM
+
+            # To quantize a text-generation model with a dataset, an instantiated OVModelForCausalLM is required
+            model = OVModelForCausalLM.from_pretrained(
+                self.args.model,
+                export=True,
+                quantization_config=quantization_config,
+                stateful=not self.args.disable_stateful,
+                trust_remote_code=self.args.trust_remote_code,
             )
             model.save_pretrained(self.args.output)

-            if self.args.disable_convert_tokenizer:
-                return
-
-            # avoid import when using other exporters (IPEX, INC)
-            from ...exporters.openvino.convert import export_tokenizer
-
-            output = Path(self.args.output)
-            tokenizer = getattr(model, "tokenizer", None)
-            if tokenizer is not None:
-                export_tokenizer(tokenizer, output / "tokenizer")
-
-            tokenizer_2 = getattr(model, "tokenizer_2", None)
-            if tokenizer_2 is not None:
-                export_tokenizer(tokenizer_2, output / "tokenizer_2")
+            maybe_save_preprocessors(self.args.model, self.args.output, trust_remote_code=self.args.trust_remote_code)
+            if not self.args.disable_convert_tokenizer:
+                preprocessors = maybe_load_preprocessors(
+                    self.args.model, trust_remote_code=self.args.trust_remote_code
+                )
+                maybe_convert_tokenizers(library_name, self.args.output, preprocessors=preprocessors)
         else:
             # TODO : add input shapes
             main_export(
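For reference, the new text-generation branch above mirrors what a user would write programmatically. A minimal sketch with an illustrative model ID, output directory, and config values, assuming nncf is installed (the CLI assembles the same quantization_config dict from its arguments):

from transformers.utils.quantization_config import QuantizationMethod

from optimum.intel import OVModelForCausalLM

# Export to OpenVINO IR first, then apply data-driven int4 weight compression (AWQ)
# on the instantiated model, as in the CLI branch above.
quantization_config = {
    "bits": 4,
    "sym": True,
    "group_size": 64,
    "ratio": 1.0,
    "dataset": "wikitext2",
    "num_samples": 100,
    "quant_method": QuantizationMethod.AWQ,
    "sensitivity_metric": "max_activation_variance",
}

model = OVModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model ID
    export=True,
    quantization_config=quantization_config,
    stateful=True,
    trust_remote_code=False,
)
model.save_pretrained("opt125m_int4_awq_ov")  # illustrative output directory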

optimum/exporters/openvino/__main__.py

+50-36
@@ -44,6 +44,22 @@
 logger = logging.getLogger(__name__)


+def infer_task(task, model_name_or_path):
+    task = TasksManager.map_from_synonym(task)
+    if task == "auto":
+        try:
+            task = TasksManager.infer_task_from_model(model_name_or_path)
+        except KeyError as e:
+            raise KeyError(
+                f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
+            )
+        except RequestsConnectionError as e:
+            raise RequestsConnectionError(
+                f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
+            )
+    return task
+
+
 def main_export(
     model_name_or_path: str,
     output: Union[str, Path],
@@ -174,7 +190,7 @@ def main_export(
         ov_config = OVConfig(quantization_config=q_config)

     original_task = task
-    task = TasksManager.map_from_synonym(task)
+    task = infer_task(task, model_name_or_path)
     framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
     library_name_is_not_provided = library_name is None
     library_name = TasksManager.infer_library_from_model(
@@ -188,18 +204,6 @@ def main_export(
         )
         library_name = "transformers"

-    if task == "auto":
-        try:
-            task = TasksManager.infer_task_from_model(model_name_or_path)
-        except KeyError as e:
-            raise KeyError(
-                f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
-        except RequestsConnectionError as e:
-            raise RequestsConnectionError(
-                f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
-
     do_gptq_patching = False
     custom_architecture = False
     loading_kwargs = {}
@@ -360,36 +364,46 @@ class StoreAttr(object):
         **kwargs_shapes,
     )

-    # hide openvino import when using other exporters
-    from optimum.exporters.openvino.convert import export_tokenizer
+    if convert_tokenizer:
+        maybe_convert_tokenizers(library_name, output, model, preprocessors)
+
+    # Unpatch modules after GPTQ export
+    if do_gptq_patching:
+        torch.cuda.is_available = orig_cuda_check
+        GPTQQuantizer.post_init_model = orig_post_init_model

-    if convert_tokenizer and is_openvino_tokenizers_available():
-        if library_name != "diffusers":
-            tokenizer = next(
-                (preprocessor for preprocessor in preprocessors if isinstance(preprocessor, PreTrainedTokenizerBase)),
-                None,
-            )

-            if tokenizer is not None:
+def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None):
+    """
+    Tries to convert tokenizers to OV format and export them to disk.
+
+    Arguments:
+        library_name (`str`):
+            The library name.
+        output (`Path`):
+            Path to save converted tokenizers to.
+        model (`PreTrainedModel`, *optional*, defaults to None):
+            Model instance.
+        preprocessors (`Iterable`, *optional*, defaults to None):
+            Iterable possibly containing tokenizers to be converted.
+    """
+    from optimum.exporters.openvino.convert import export_tokenizer
+
+    if is_openvino_tokenizers_available():
+        if library_name != "diffusers" and preprocessors:
+            tokenizer = next(filter(lambda it: isinstance(it, PreTrainedTokenizerBase), preprocessors), None)
+            if tokenizer:
                 try:
                     export_tokenizer(tokenizer, output)
                 except Exception as exception:
                     logger.warning(
                         "Could not load tokenizer using specified model ID or path. OpenVINO tokenizer/detokenizer "
                         f"models won't be generated. Exception: {exception}"
                     )
-        else:
-            tokenizer = getattr(model, "tokenizer", None)
-            if tokenizer is not None:
-                export_tokenizer(tokenizer, output / "tokenizer")
-
-            tokenizer_2 = getattr(model, "tokenizer_2", None)
-            if tokenizer_2 is not None:
-                export_tokenizer(tokenizer_2, output / "tokenizer_2")
-    elif convert_tokenizer and not is_openvino_tokenizers_available():
+        elif model:
+            for tokenizer_name in ("tokenizer", "tokenizer_2"):
+                tokenizer = getattr(model, tokenizer_name, None)
+                if tokenizer:
+                    export_tokenizer(tokenizer, output / tokenizer_name)
+    else:
         logger.warning("Tokenizer won't be converted.")
-
-    # Unpatch modules after GPTQ export
-    if do_gptq_patching:
-        torch.cuda.is_available = orig_cuda_check
-        GPTQQuantizer.post_init_model = orig_post_init_model
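For reference, a minimal sketch of how the two helpers factored out here might be called on their own (the model ID and output path are illustrative; when openvino-tokenizers is not installed, maybe_convert_tokenizers only logs a warning):

from pathlib import Path

from transformers import AutoTokenizer

from optimum.exporters.openvino.__main__ import infer_task, maybe_convert_tokenizers

# Resolve an "auto" task into a concrete one for a Hub-hosted model (needs network access).
task = infer_task("auto", "gpt2")

# Convert a tokenizer supplied through `preprocessors`, as the non-diffusers branch expects.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
maybe_convert_tokenizers("transformers", Path("ov_out"), preprocessors=[tokenizer])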

optimum/intel/openvino/modeling_base.py

+1
@@ -135,6 +135,7 @@ def fix_op_names_duplicates(model: openvino.runtime.Model):
         if file_name.suffix == ".onnx":
             model = fix_op_names_duplicates(model)  # should be called during model conversion to IR

+        # TODO: remove this way of applying quantization; instead apply it after instance of OVModel* is loaded
         if quantization_config:
             if not is_nncf_available():
                 raise ImportError(

optimum/intel/openvino/modeling_decoder.py

+9-13
@@ -752,17 +752,7 @@ def _from_pretrained(
             local_files_only=local_files_only,
         )

-        if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
-            quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
-
-        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
-
-        load_in_4bit = quantization_config.bits == 4 if quantization_config else False
-
-        model = cls.load_model(
-            model_cache_path,
-            quantization_config=None if load_in_4bit else quantization_config,
-        )
+        model = cls.load_model(model_cache_path)

         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
@@ -772,7 +762,12 @@ def _from_pretrained(
         else:
             init_cls = cls

-        enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
+        if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
+            quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
+        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
+
+        enable_compilation = kwargs.pop("compile", True) and not quantization_config
+
         try:
             generation_config = GenerationConfig.from_pretrained(
                 model_id,
@@ -785,6 +780,7 @@ def _from_pretrained(
             kwargs["generation_config"] = generation_config
         except Exception:
             pass
+
         causal_model = init_cls(
             model=model,
             config=config,
@@ -794,7 +790,7 @@ def _from_pretrained(
             **kwargs,
         )

-        if load_in_4bit:
+        if quantization_config:
             if not is_nncf_available():
                 raise ImportError(
                     "Quantization of the weights requires nncf, please install it with `pip install nncf`"

tests/openvino/test_exporters_cli.py

+11-1
@@ -89,6 +89,14 @@ class OVCLIExportTestCase(unittest.TestCase):
         ("text-generation-with-past", "opt125m", "int4_sym_g64", 62, 86),
         ("text-generation-with-past", "opt125m", "int4_asym_g64", 62, 86),
         ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 32),
+        (
+            "text-generation-with-past",
+            "llama_awq",
+            "int4 --ratio 1.0 --sym --group-size 16 --awq --dataset wikitext2 --num-samples 100 "
+            "--sensitivity-metric max_activation_variance",
+            4,
+            28,
+        ),
     ]

     def _openvino_export(
@@ -197,17 +205,19 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in
     @parameterized.expand(TEST_4BIT_CONFIGURATONS)
     def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_int8: int, expected_int4: int):
         with TemporaryDirectory() as tmpdir:
-            subprocess.run(
+            result = subprocess.run(
                 f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}",
                 shell=True,
                 check=True,
+                capture_output=True,
             )
             model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
             model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

             _, num_int8, num_int4 = get_num_quantized_nodes(model)
             self.assertEqual(expected_int8, num_int8)
             self.assertEqual(expected_int4, num_int4)
+            self.assertTrue("--awq" not in option or b"Applying AWQ" in result.stdout)

     def test_exporters_cli_help(self):
         subprocess.run(
