
Commit 85c925a

Merge branch 'main' into del-convert-tokenizer-flag
2 parents 80d4c1d + 4651ac2


15 files changed: +411 −256 lines


.github/workflows/test_inc.yml

+10 −4

@@ -32,11 +32,17 @@ jobs:
           python -m pip install --upgrade pip
           pip install cmake
           pip install py-cpuinfo
-          pip install torch==2.1.0 torchaudio==2.1.0 torchvision==0.16 --extra-index-url https://download.pytorch.org/whl/cpu
           pip install .[neural-compressor,diffusers,tests]
-          pip install intel-extension-for-pytorch==2.1.100
-          pip install intel-extension-for-transformers==1.3.2
+          pip install intel-extension-for-transformers
           pip install peft
+
       - name: Test with Pytest
         run: |
-          pytest tests/neural_compressor/
+          pytest tests/neural_compressor/ --ignore tests/neural_compressor/test_ipex.py --durations=0
+      - name: Test IPEX
+        run: |
+          pip uninstall -y intel-extension-for-transformers
+          pip install torch==2.1.0 torchaudio==2.1.0 torchvision==0.16 --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install intel-extension-for-pytorch==2.1.100
+          pytest tests/neural_compressor/test_ipex.py

README.md

+7 −1

@@ -78,12 +78,18 @@ It is possible to export your model to the [OpenVINO IR](https://docs.openvino.a
 optimum-cli export openvino --model gpt2 ov_model
 ```
 
-You can also apply 8-bit weight-only quantization when exporting your model : the model linear and embedding weights will be quantized to INT8, the activations will be kept in floating point precision.
+You can also apply 8-bit weight-only quantization when exporting your model: the linear, embedding and convolution weights will be quantized to INT8, while the activations are kept in floating point precision.
 
 ```plain
 optimum-cli export openvino --model gpt2 --weight-format int8 ov_model
 ```
 
+Quantization in hybrid mode can be applied to a Stable Diffusion pipeline during model export. This involves applying hybrid post-training quantization to the UNet model and weight-only quantization to the rest of the pipeline components. In hybrid mode, weights in MatMul and Embedding layers are quantized, as well as activations of other layers.
+
+```plain
+optimum-cli export openvino --model stabilityai/stable-diffusion-2-1 --dataset conceptual_captions --weight-format int8 ov_model
+```
+
 To apply quantization on both weights and activations, you can find more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov).
 
 #### Inference:
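For reference, the same 8-bit weight-only export can also be done from the Python API rather than the CLI. A minimal sketch, assuming optimum-intel's `OVModelForCausalLM` class and its `load_in_8bit` argument (neither appears in this diff):

```python
# Minimal sketch of an 8-bit weight-only export from Python
# (assumes OVModelForCausalLM and load_in_8bit are available in this optimum-intel version).
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained(
    "gpt2",
    export=True,        # convert the PyTorch checkpoint to OpenVINO IR on the fly
    load_in_8bit=True,  # quantize linear/embedding/convolution weights to INT8
)
model.save_pretrained("ov_model")
```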

examples/neural_compressor/language-modeling/run_clm.py

+21 −20

@@ -64,8 +64,7 @@
 
 
 if is_intel_extension_for_transformers_available():
-    from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig
-
+    from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
@@ -227,8 +226,9 @@ class OptimizationArguments:
         metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."},
     )
     quantization_methodology: str = field(
-        default="RTN",
-        metadata={"help": "Quantization methodology for weight only quantization. Choose from 'RTN' and 'GPTQ'."},
+        choices=["rtn", "gptq"],
+        default="rtn",
+        metadata={"help": "Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'."},
     )
     damp_percent: float = field(
         default=0.01,
@@ -662,22 +662,23 @@ def compute_metrics(eval_preds):
                raise ImportError(INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR.format("WeightOnly quantization"))
            if optim_args.apply_pruning or optim_args.apply_distillation:
                raise ValueError("Weight only quantization and pruning or distillation cannot be combined.")
-            if optim_args.quantization_methodology == "GPTQ":
-                algorithm_args = {
-                    "act_order": False,
-                    "percdamp": optim_args.damp_percent,
-                    "block_size": optim_args.gptq_block_size,
-                    "nsamples": optim_args.num_calibration_samples,
-                    "use_max_length": optim_args.use_max_length,
-                    "pad_max_length": optim_args.pad_max_length,
-                }
-            quantization_config = WeightOnlyQuantConfig(
-                weight_dtype=optim_args.weight_dtype,
-                group_size=optim_args.group_size,
-                scheme=optim_args.weight_only_scheme,
-                algorithm=optim_args.quantization_methodology,
-                algorithm_args=algorithm_args if optim_args.quantization_methodology == "GPTQ" else None,
-            )
+
+            algorithm_args = {
+                "weight_dtype": optim_args.weight_dtype,
+                "sym": optim_args.weight_only_scheme == "sym",
+                "group_size": optim_args.group_size,
+            }
+
+            if optim_args.quantization_methodology == "gptq":
+                quantization_config = GPTQConfig(
+                    damp_percent=optim_args.damp_percent,
+                    nsamples=optim_args.num_calibration_samples,
+                    blocksize=optim_args.gptq_block_size,
+                    **algorithm_args,
+                )
+            else:
+                quantization_config = RtnConfig(**algorithm_args)
+
        else:
            quantization_config = PostTrainingQuantConfig(
                approach=optim_args.quantization_approach, recipes=recipes
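The refactor replaces the single `WeightOnlyQuantConfig` with one config class per methodology. The sketch below is illustrative only: it shows how the script's arguments map onto the new classes, using made-up example values; the import path and keyword names are the ones introduced in the diff above.

```python
# Illustrative mapping from run_clm.py arguments to the new ITREX config classes.
# The import path and keyword names come from the diff above; the values are examples.
from intel_extension_for_transformers.transformers.utils.config import GPTQConfig, RtnConfig

common = {
    "weight_dtype": "int4_clip",  # --weight_dtype (example value)
    "sym": True,                  # --weight_only_scheme sym
    "group_size": 128,            # --group_size
}

# --quantization_methodology gptq: data-aware, needs calibration samples
gptq_config = GPTQConfig(damp_percent=0.01, nsamples=128, blocksize=128, **common)

# --quantization_methodology rtn (default): round-to-nearest, no calibration data
rtn_config = RtnConfig(**common)
```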

optimum/commands/export/openvino.py

+70 −18

@@ -19,6 +19,7 @@
 from typing import TYPE_CHECKING, Optional
 
 from ...exporters import TasksManager
+from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
 from ..base import BaseOptimumCLICommand, CommandInfo
 
 
@@ -104,6 +105,16 @@ def parse_args_openvino(parser: "ArgumentParser"):
         default=None,
         help=("The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization."),
     )
+    optional_group.add_argument(
+        "--dataset",
+        type=str,
+        default=None,
+        help=(
+            "The dataset used for data-aware compression or quantization with NNCF. "
+            "You can use the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLMs "
+            "or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models."
+        ),
+    )
     optional_group.add_argument(
         "--disable-stateful",
         action="store_true",
@@ -200,23 +211,64 @@ def run(self):
             )
             quantization_config["sym"] = "asym" not in self.args.weight_format
             quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
+            quantization_config["dataset"] = self.args.dataset
         ov_config = OVConfig(quantization_config=quantization_config)
 
-        if self.args.convert_tokenizer:
-            logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.")
-
-        # TODO : add input shapes
-        main_export(
-            model_name_or_path=self.args.model,
-            output=self.args.output,
-            task=self.args.task,
-            framework=self.args.framework,
-            cache_dir=self.args.cache_dir,
-            trust_remote_code=self.args.trust_remote_code,
-            pad_token_id=self.args.pad_token_id,
-            ov_config=ov_config,
-            stateful=not self.args.disable_stateful,
-            convert_tokenizer=not self.args.disable_convert_tokenizer,
-            library_name=self.args.library
-            # **input_shapes,
-        )
+        library_name = TasksManager.infer_library_from_model(self.args.model)
+
+        if (
+            library_name == "diffusers"
+            and ov_config
+            and ov_config.quantization_config
+            and ov_config.quantization_config.dataset is not None
+        ):
+            if not is_diffusers_available():
+                raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models"))
+
+            from diffusers import DiffusionPipeline
+
+            diffusers_config = DiffusionPipeline.load_config(self.args.model)
+            class_name = diffusers_config.get("_class_name", None)
+
+            if class_name == "LatentConsistencyModelPipeline":
+                from optimum.intel import OVLatentConsistencyModelPipeline
+
+                model_cls = OVLatentConsistencyModelPipeline
+
+            elif class_name == "StableDiffusionXLPipeline":
+                from optimum.intel import OVStableDiffusionXLPipeline
+
+                model_cls = OVStableDiffusionXLPipeline
+            elif class_name == "StableDiffusionPipeline":
+                from optimum.intel import OVStableDiffusionPipeline
+
+                model_cls = OVStableDiffusionPipeline
+            else:
+                raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")
+
+            model = model_cls.from_pretrained(
+                self.args.model, export=True, quantization_config=ov_config.quantization_config
+            )
+            model.save_pretrained(self.args.output)
+
+        else:
+            if self.args.convert_tokenizer:
+                logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.")
+
+            # TODO : add input shapes
+            main_export(
+                model_name_or_path=self.args.model,
+                output=self.args.output,
+                task=self.args.task,
+                framework=self.args.framework,
+                cache_dir=self.args.cache_dir,
+                trust_remote_code=self.args.trust_remote_code,
+                pad_token_id=self.args.pad_token_id,
+                ov_config=ov_config,
+                stateful=not self.args.disable_stateful,
+                convert_tokenizer=not self.args.disable_convert_tokenizer,
+                library_name=library_name,
+                # **input_shapes,
+            )
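The CLI branch above delegates to the OV diffusion pipeline classes, so the same hybrid quantization can be reproduced from Python by mirroring `run()`. A sketch: `OVConfig`, `OVStableDiffusionPipeline`, and the `sym`/`group_size`/`dataset` keys come from the hunk above, while the `"bits": 8` entry is an assumption for the int8 weight format.

```python
# Sketch of the programmatic path the CLI takes for a Stable Diffusion model,
# mirroring run() above. "bits": 8 is assumed; the other keys appear in the diff.
from optimum.intel import OVConfig, OVStableDiffusionPipeline

quantization_config = {"bits": 8, "sym": True, "group_size": 64, "dataset": "conceptual_captions"}
ov_config = OVConfig(quantization_config=quantization_config)

pipe = OVStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    export=True,
    quantization_config=ov_config.quantization_config,  # dataset present -> hybrid quantization
)
pipe.save_pretrained("ov_model")
```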

optimum/exporters/openvino/__main__.py

+1 −1

@@ -76,7 +76,7 @@ def main_export(
         model_name_or_path (`str`):
             Model ID on huggingface.co or path on disk to the model repository to export.
         output (`Union[str, Path]`):
-            Path indicating the directory where to store the generated ONNX model.
+            Path indicating the directory where to store the generated OpenVINO model.
 
         > Optional parameters
 

optimum/intel/neural_compressor/modeling_base.py

+10 −14

@@ -67,11 +67,6 @@
 """
 
 
-if is_intel_extension_for_transformers_available():
-    from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM as ITREX_WOQ_MODEL
-    from intel_extension_for_transformers.transformers.utils import WeightOnlyQuantConfig
-
-
 class INCModel(OptimizedModel):
     auto_model_class = AutoModel
     export_feature = "feature-extraction"
@@ -142,15 +137,16 @@ def _from_pretrained(
         msg = None
         if is_intel_extension_for_transformers_available():
             try:
-                quantization_config = WeightOnlyQuantConfig.from_pretrained(model_id)
-                algorithm = getattr(quantization_config, "algorithm", None)
-                if algorithm is not None and quantization_config.algorithm.lower() in {
-                    "rtn",
-                    "gptq",
-                    "awq",
-                    "autoaround",
-                }:
-                    return ITREX_WOQ_MODEL.from_pretrained(
+                quantization_config = PretrainedConfig.from_pretrained(model_save_dir / "quantize_config.json")
+                algorithm = getattr(quantization_config, "quant_method", None)
+                if algorithm in {"rtn", "gptq", "awq", "autoaround"}:
+                    from intel_extension_for_transformers.transformers.modeling.modeling_auto import (
+                        _BaseQBitsAutoModelClass,
+                    )
+
+                    _BaseQBitsAutoModelClass.ORIG_MODEL = cls.auto_model_class
+
+                    return _BaseQBitsAutoModelClass.from_pretrained(
                         pretrained_model_name_or_path=model_id,
                         use_auth_token=use_auth_token,
                         revision=revision,
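After this change, a weight-only quantized checkpoint is recognized from its `quantize_config.json` (via the `quant_method` field) rather than from an ITREX `WeightOnlyQuantConfig`. A minimal loading sketch; the directory path is hypothetical, and `INCModelForCausalLM` is assumed to be the causal-LM counterpart of the `INCModel` class shown above.

```python
# Minimal loading sketch: a directory containing quantize_config.json with
# quant_method in {"rtn", "gptq", "awq", "autoaround"} is routed through
# intel-extension-for-transformers' _BaseQBitsAutoModelClass.
from optimum.intel import INCModelForCausalLM

model = INCModelForCausalLM.from_pretrained("path/to/woq_model")  # hypothetical path
```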

optimum/intel/neural_compressor/modeling_decoder.py

−27
This file was deleted.
