Commit 56878bb

Add quantization with dataset after model export for text-generation models
1 parent 7114900 commit 56878bb
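
In short: `optimum-cli export openvino` gains `--quant-method`, `--sensitivity-metric`, and `--num-samples` options; task inference is factored out of `main_export` into a reusable `infer_task` helper; and when a text-generation model is exported with a calibration `--dataset`, weight quantization is now skipped during export and applied afterwards through `OVQuantizer`, since that step needs a loaded `OVModelForCausalLM` instance.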

File tree

3 files changed · +89 -16 lines changed


optimum/commands/export/openvino.py (+61 -2)
@@ -15,6 +15,7 @@
 
 import logging
 import sys
+import tempfile
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional
 
@@ -128,6 +129,29 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "compression is applied, they are compressed to INT8."
         ),
     )
+    optional_group.add_argument(
+        "--quant-method",
+        type=str,
+        default=None,
+        choices=["default", "awq", "hybrid"],
+        help=("The quantization method to apply. Can be one of the following: ['default', 'awq', 'hybrid']."),
+    )
+    optional_group.add_argument(
+        "--sensitivity-metric",
+        type=str,
+        default=None,
+        help=(
+            "The sensitivity metric for assigning quantization precision to layers. Can be one of the following: "
+            "['weight_quantization_error', 'hessian_input_activation', 'mean_activation_variance', "
+            "'max_activation_variance', 'mean_activation_magnitude']."
+        ),
+    )
+    optional_group.add_argument(
+        "--num-samples",
+        type=int,
+        default=None,
+        help=("The maximum number of samples composing the calibration dataset for quantization."),
+    )
     optional_group.add_argument(
         "--disable-stateful",
         action="store_true",
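
For context, these flags feed the `quantization_config` dict assembled in `run()` below. A minimal programmatic sketch of the same configuration, assuming `optimum.intel`'s `OVWeightQuantizationConfig` accepts the string forms of these fields (the model id and all values are illustrative):

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# Mirrors what the CLI assembles from --quant-method / --sensitivity-metric /
# --num-samples / --dataset; all values here are illustrative.
quantization_config = OVWeightQuantizationConfig(
    bits=4,
    sym=True,
    ratio=1.0,
    group_size=16,
    dataset="wikitext2",                           # --dataset
    num_samples=100,                               # --num-samples
    quant_method="awq",                            # --quant-method
    sensitivity_metric="max_activation_variance",  # --sensitivity-metric
)
model = OVModelForCausalLM.from_pretrained(
    "<model-id>",  # placeholder
    export=True,
    quantization_config=quantization_config,
)
```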
@@ -180,7 +204,7 @@ def parse_args(parser: "ArgumentParser"):
         return parse_args_openvino(parser)
 
     def run(self):
-        from ...exporters.openvino.__main__ import main_export
+        from ...exporters.openvino.__main__ import main_export, infer_task
         from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS, OVConfig
 
         if self.args.fp16:
@@ -208,6 +232,9 @@ def run(self):
             and self.args.group_size is None
             and self.args.sym is None
             and self.args.all_layers is None
+            and self.args.dataset is None
+            and self.args.quant_method is None
+            and self.args.sensitivity_metric is None
             and self.args.model in _DEFAULT_4BIT_CONFIGS
         ):
             quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
@@ -218,6 +245,10 @@ def run(self):
                 "sym": self.args.sym or False,
                 "group_size": -1 if is_int8 else self.args.group_size,
                 "all_layers": None if is_int8 else self.args.all_layers,
+                "dataset": self.args.dataset,
+                "num_samples": self.args.num_samples,
+                "quant_method": self.args.quant_method,
+                "sensitivity_metric": self.args.sensitivity_metric,
             }
 
             if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
@@ -226,7 +257,6 @@ def run(self):
                 )
                 quantization_config["sym"] = "asym" not in self.args.weight_format
                 quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
-            quantization_config["dataset"] = self.args.dataset
             ov_config = OVConfig(quantization_config=quantization_config)
 
         library_name = TasksManager.infer_library_from_model(self.args.model, library_name=self.args.library)
@@ -290,6 +320,19 @@ def run(self):
             if tokenizer_2 is not None:
                 export_tokenizer(tokenizer_2, output / "tokenizer_2")
         else:
+            task = infer_task(self.args.task, self.args.model)
+            quantization_config = ov_config.quantization_config if ov_config else None
+            quantize_after_export = (
+                task.startswith("text-generation")
+                and quantization_config is not None
+                and hasattr(quantization_config, "dataset")
+                and quantization_config.dataset is not None
+            )
+            if quantize_after_export:
+                # In order to quantize a text-generation model with a dataset, an instance of OVModelForCausalLM is
+                # required. That's why the quantization is skipped during export and applied explicitly after export.
+                ov_config.quantization_config = None
+
             # TODO : add input shapes
             main_export(
                 model_name_or_path=self.args.model,
@@ -305,3 +348,19 @@ def run(self):
                 library_name=library_name,
                 # **input_shapes,
             )
+
+            if quantize_after_export:
+                from optimum.intel import OVModelForCausalLM, OVQuantizer
+
+                model = OVModelForCausalLM.from_pretrained(self.args.output)
+                quantizer = OVQuantizer(model)
+                quantization_config.tokenizer = quantization_config.tokenizer or str(self.args.output)
+                # TODO: set save_directory=self.args.output once OV is updated to 2024.3
+                quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
+                with tempfile.TemporaryDirectory() as temp_dir:
+                    import shutil
+
+                    model.save_pretrained(temp_dir)
+                    ov_config.save_pretrained(self.args.output)
+                    shutil.copy(f"{temp_dir}/openvino_model.xml", f"{self.args.output}/openvino_model.xml")
+                    shutil.copy(f"{temp_dir}/openvino_model.bin", f"{self.args.output}/openvino_model.bin")
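
Pulled out of the command class, the post-export step amounts to the following standalone sketch. The output directory is a placeholder, and the `OVQuantizer.quantize` call shape is assumed from the diff above:

```python
import shutil
import tempfile

from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer, OVWeightQuantizationConfig

output = "exported_model_dir"  # placeholder: directory produced by a prior export

# As in run(): the tokenizer falls back to the export directory when not set explicitly.
quantization_config = OVWeightQuantizationConfig(bits=4, dataset="wikitext2", tokenizer=output)

# Quantization was skipped during export, so load the exported model and apply it now.
model = OVModelForCausalLM.from_pretrained(output)
quantizer = OVQuantizer(model)
quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))

# Save the quantized IR to a scratch directory, then overwrite the exported weights.
with tempfile.TemporaryDirectory() as temp_dir:
    model.save_pretrained(temp_dir)
    shutil.copy(f"{temp_dir}/openvino_model.xml", f"{output}/openvino_model.xml")
    shutil.copy(f"{temp_dir}/openvino_model.bin", f"{output}/openvino_model.bin")
```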

optimum/exporters/openvino/__main__.py (+17 -13)
@@ -44,6 +44,22 @@
 logger = logging.getLogger(__name__)
 
 
+def infer_task(task, model_name_or_path):
+    task = TasksManager.map_from_synonym(task)
+    if task == "auto":
+        try:
+            task = TasksManager.infer_task_from_model(model_name_or_path)
+        except KeyError as e:
+            raise KeyError(
+                f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
+            )
+        except RequestsConnectionError as e:
+            raise RequestsConnectionError(
+                f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
+            )
+    return task
+
+
 def main_export(
     model_name_or_path: str,
     output: Union[str, Path],
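
A small usage sketch of the extracted helper (the model id is illustrative): explicit tasks pass through after synonym mapping, while "auto" triggers Hub-based inference.

```python
from optimum.exporters.openvino.__main__ import infer_task

# An explicit task is only mapped through TasksManager.map_from_synonym(...).
task = infer_task("text-generation-with-past", "unused-when-task-is-explicit")

# "auto" is resolved by querying the Hub; e.g. a GPT-2 checkpoint would resolve
# to "text-generation" (raises if the model is not hosted on the Hub).
task = infer_task("auto", "gpt2")
```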
@@ -174,7 +190,7 @@ def main_export(
         ov_config = OVConfig(quantization_config=q_config)
 
     original_task = task
-    task = TasksManager.map_from_synonym(task)
+    task = infer_task(task, model_name_or_path)
     framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
     library_name_is_not_provided = library_name is None
     library_name = TasksManager.infer_library_from_model(
@@ -188,18 +204,6 @@ def main_export(
         )
         library_name = "transformers"
 
-    if task == "auto":
-        try:
-            task = TasksManager.infer_task_from_model(model_name_or_path)
-        except KeyError as e:
-            raise KeyError(
-                f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
-        except RequestsConnectionError as e:
-            raise RequestsConnectionError(
-                f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
-
     do_gptq_patching = False
     custom_architecture = False
     loading_kwargs = {}

tests/openvino/test_exporters_cli.py (+11 -1)
@@ -89,6 +89,14 @@ class OVCLIExportTestCase(unittest.TestCase):
         ("text-generation-with-past", "opt125m", "int4_sym_g64", 62, 86),
         ("text-generation-with-past", "opt125m", "int4_asym_g64", 62, 86),
         ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 32),
+        (
+            "text-generation-with-past",
+            "llama_awq",
+            "int4 --ratio 1.0 --sym --group-size 16 --quant-method awq --dataset wikitext2 --num-samples 100 "
+            "--sensitivity-metric max_activation_variance",
+            4,
+            28,
+        ),
     ]
 
     def _openvino_export(
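
The new case exercises the AWQ path end to end; expanded, the command under test looks like the sketch below. The model id is a placeholder for `MODEL_NAMES["llama_awq"]`, which is elided here, and the final two tuple fields are the expected counts of INT8 and INT4 weight nodes (4 and 28).

```python
import subprocess

model_id = "<llama-test-checkpoint>"  # placeholder for MODEL_NAMES["llama_awq"]
subprocess.run(
    f"optimum-cli export openvino --model {model_id} --task text-generation-with-past "
    "--weight-format int4 --ratio 1.0 --sym --group-size 16 --quant-method awq "
    "--dataset wikitext2 --num-samples 100 --sensitivity-metric max_activation_variance out_dir",
    shell=True,
    check=True,
)
```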
@@ -197,17 +205,19 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in
     @parameterized.expand(TEST_4BIT_CONFIGURATONS)
     def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_int8: int, expected_int4: int):
         with TemporaryDirectory() as tmpdir:
-            subprocess.run(
+            result = subprocess.run(
                 f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}",
                 shell=True,
                 check=True,
+                capture_output=True,
             )
             model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
             model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)
 
             _, num_int8, num_int4 = get_num_quantized_nodes(model)
             self.assertEqual(expected_int8, num_int8)
             self.assertEqual(expected_int4, num_int4)
+            self.assertTrue("--quant-method awq" not in option or b"Applying AWQ" in result.stdout)
 
     def test_exporters_cli_help(self):
         subprocess.run(
