
Commit 4ae2e87

Authored by xin3he
support quant_lm_head arg in all WOQ configs (#1881)
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent cc763f5 commit 4ae2e87

19 files changed: +379 −181 lines changed


docs/3x/PT_WeightOnlyQuant.md

+26-6
@@ -1,6 +1,7 @@
 
 PyTorch Weight Only Quantization
 ===============
+
 - [Introduction](#introduction)
 - [Supported Matrix](#supported-matrix)
 - [Usage](#usage)
@@ -28,7 +29,6 @@ Besides, as mentioned in many papers[1][2], activation quantization is the main
 
 Theoretically, round-to-nearest (RTN) is the most straightforward way to quantize weight using scale maps. However, when the number of bits is small (e.g. 3), the MSE loss is larger than expected. A group size is introduced to reduce elements using the same scale to improve accuracy.
 
-
 ## Supported Matrix
 
 | Algorithms/Backend | PyTorch eager mode |
@@ -58,25 +58,29 @@ Theoretically, round-to-nearest (RTN) is the most straightforward way to quantiz
 WeightOnlyQuant quantization for PyTorch is using prepare and convert [APIs](./PyTorch.md#quantization-apis).
 
 #### Common arguments
+
 | Config | Capability |
 |---|---|
 | dtype (str)| ['int', 'nf4', 'fp4'] |
 | bits (int)| [1, ..., 8] |
 | group_size (int)| [-1, 1, ..., $C_{in}$] |
 | use_sym (bool)| [True, False] |
+| quant_lm_head (bool)| [False, True] |
 | use_double_quant (bool) | [True, False] |
 | double_quant_dtype (str) | ['int'] |
 | double_quant_bits (int) | [1, ..., bits] |
 | double_quant_use_sym (bool) | [True, False] |
 | double_quant_group_size (int) | [-1, 1, ..., $C_{in}$] |
 
 Notes:
+
 - *group_size = -1* refers to **per output channel quantization**. Taking a linear layer (input channel = $C_{in}$, output channel = $C_{out}$) for instance, when *group size = -1*, quantization will calculate total $C_{out}$ quantization parameters. Otherwise, when *group_size = gs* quantization parameters are calculate with every $gs$ elements along with the input channel, leading to total $C_{out} \times (C_{in} / gs)$ quantization parameters.
 - 4-bit NormalFloat(NF4) is proposed in QLoRA[7]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb.
-- Only RTN and GPTQ support double quant.
-
+- *quant_lm_head* defaults to False. This means that, except for transformer blocks, the last layer in transformer models will not be quantized by default. The last layer may be named "lm_head", "output_layer" or "embed_out".
+- Only RTN and GPTQ support double quant.
 
 #### RTN
+
 | rtn_args | comments | default value |
 |----------|-------------|-------------------------------------------------------------------|
 | group_dim (int) | Dimension for grouping | 1 |
@@ -86,6 +90,7 @@ Notes:
 | model_path (str) | Model path that is used to load state_dict per layer | |
 
 > **Notes:** `model_path` is only used when use_layer_wise=True. `layer-wise` is stay-tuned.
+
 ``` python
 # Quantization code
 from neural_compressor.torch.quantization import prepare, convert, RTNConfig
@@ -96,6 +101,7 @@ model = convert(model)
 ```
 
 #### GPTQ
+
 | gptq_args | comments | default value |
 |----------|-------------|-------------------------------------------------------------------|
 | use_mse_search (bool) | Enables mean squared error (MSE) search | False
@@ -107,6 +113,7 @@ model = convert(model)
 | block_size (int) | Execute GPTQ quantization per block, block shape = [C_out, block_size] | 128 |
 | static_groups (bool) | Whether to calculate group wise quantization parameters in advance. This option mitigate actorder's extra computational requirements. | False. |
 > **Note:** `model_path` is only used when use_layer_wise=True. `layer-wise` is stay-tuned.
+
 ``` python
 # Quantization code
 from neural_compressor.torch.quantization import prepare, convert, GPTQConfig
@@ -118,6 +125,7 @@ model = convert(model)
 ```
 
 #### AutoRound
+
 | autoround_args | comments | default value |
 |----------|-------------|-------------------------------------------------------------------|
 | enable_full_range (bool) | Whether to enable full range quantization | False
@@ -138,6 +146,7 @@ model = convert(model)
 | not_use_best_mse (bool) | Whether to use mean squared error | False |
 | dynamic_max_gap (int) | The dynamic maximum gap | -1 |
 | scale_dtype (str) | The data type of quantization scale to be used, different kernels have different choices | "float16" |
+
 ``` python
 # Quantization code
 from neural_compressor.torch.quantization import prepare, convert, AutoRoundConfig
@@ -149,6 +158,7 @@ model = convert(model)
 ```
 
 #### AWQ
+
 | awq_args | comments | default value |
 |----------|-------------|-------------------------------------------------------------------|
 | group_dim (int) | Dimension for grouping | 1 |
@@ -159,6 +169,7 @@ model = convert(model)
 | use_auto_clip (bool) | Enables clip range search | True |
 | folding(bool) | Allow insert mul before linear when the scale cannot be absorbed by last layer | False. |
 > **Notes:** `layer-wise` is stay-tuned.
+
 ``` python
 # Quantization code
 from neural_compressor.torch.quantization import prepare, convert, AWQConfig
@@ -170,6 +181,7 @@ model = convert(model)
 ```
 
 #### TEQ
+
 | teq_args | comments | default value |
 |----------|-------------|-------------------------------------------------------------------|
 | group_dim (int) | Dimension for grouping | 1 |
@@ -179,6 +191,7 @@ model = convert(model)
 | use_double_quant (bool) | Enables double quantization | False |
 | folding(bool) | Allow insert mul before linear when the scale cannot be absorbed by last layer | False |
 > **Notes:** `layer-wise` is stay-tuned.
+
 ``` python
 # Quantization code
 from neural_compressor.torch.quantization import prepare, convert, TEQConfig
@@ -190,12 +203,13 @@ model = convert(model)
 ```
 
 #### HQQ
+
 | hqq_args | comments | default value |
 |----------|-------------|-------------------------------------------------------------------|
 | quant_zero (bool) | Whether to quantize zero point | True |
 | quant_scale: (bool) | Whether to quantize scale: point | False |
 | scale_quant_group_size (int) | The group size for quantizing scale | 128 |
-| skip_lm_head (bool) | Whether to skip for quantizing lm_head | True |
+
 ``` python
 # Quantization code
 from neural_compressor.torch.quantization import prepare, convert, HQQConfig
@@ -205,10 +219,13 @@ model = prepare(model, quant_config)
 run_fn(model) # calibration
 model = convert(model)
 ```
+
 ### Specify Quantization Rules
+
 Intel(R) Neural Compressor support specify quantization rules by operator name or operator type. Users can set `local` in dict or use `set_local` method of config class to achieve the above purpose.
 
 1. Example of setting `local` from a dict
+
 ```python
 quant_config = {
     "rtn": {
@@ -226,15 +243,19 @@ quant_config = {
     }
 }
 ```
+
 2. Example of using `set_local`
+
 ```python
 quant_config = RTNConfig()
 lm_head_config = RTNConfig(dtype="fp32")
 quant_config.set_local("lm_head", lm_head_config)
 ```
 
 ### Saving and Loading
+
 The saved_results folder contains two files: quantized_model.pt and qconfig.json, and the generated model is a quantized model. The quantitative model will include WeightOnlyLinear. To support low memory inference, Intel(R) Neural Compressor implemented WeightOnlyLinear, a torch.nn.Module, to compress the fake quantized fp32 model. Since torch does not provide flexible data type storage, WeightOnlyLinear combines low bits data into a long date type, such as torch.int8 and torch.int32. Low bits data includes weights and zero points. When using WeightOnlyLinear for inference, it will restore the compressed data to float32 and run torch linear function.
+
 ```python
 # Quantization code
 from neural_compressor.torch.quantization import prepare, convert, RTNConfig
@@ -255,7 +276,6 @@ loaded_model = load(
 ) # Please note that the original_model parameter passes the original model.
 ```
 
-
 ## Examples
 
 Users can also refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only) on how to quantize a model with WeightOnlyQuant.
@@ -272,6 +292,6 @@ Users can also refer to [examples](https://github.com/intel/neural-compressor/bl
 
 [5]. Cheng, Wenhua, et al. "Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs" arXiv preprint arXiv:2309.05516 (2023).
 
-[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: https://mobiusml.github.io/hqq_blog/ (2023).
+[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: <https://mobiusml.github.io/hqq_blog/> (2023).
 
 [7]. Dettmers, Tim, et al. "Qlora: Efficient finetuning of quantized llms." arXiv preprint arXiv:2305.14314 (2023).
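For context, a minimal sketch of how the newly documented `quant_lm_head` argument can be used with the RTN flow described above. The checkpoint name is illustrative, and RTN typically needs no calibration data:

```python
from transformers import AutoModelForCausalLM

from neural_compressor.torch.quantization import prepare, convert, RTNConfig

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative checkpoint

# quant_lm_head=True opts the last layer ("lm_head", "output_layer" or "embed_out")
# into weight-only quantization; by default (False) it is left unquantized.
quant_config = RTNConfig(quant_lm_head=True)
model = prepare(model, quant_config)
model = convert(model)
```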

docs/3x/PyTorch.md

+21
@@ -223,3 +223,24 @@ def load(output_dir="./saved_results", model=None):
 </tr>
 </tbody>
 </table>
+
+2. How to set different configuration for specific op_name or op_type?
+> INC extends a `set_local` method based on the global configuration object to set custom configuration.
+
+```python
+def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
+    """Set custom configuration based on the global configuration object.
+
+    Args:
+        operator_name_or_list (Union[List, str, Callable]): specific operator
+        config (BaseConfig): specific configuration
+    """
+```
+
+> Demo:
+
+```python
+quant_config = RTNConfig() # Initialize global configuration with default bits=4
+quant_config.set_local(".*mlp.*", RTNConfig(bits=8)) # For layers with "mlp" in their names, set bits=8
+quant_config.set_local("Conv1d", RTNConfig(dtype="fp32")) # For Conv1d layers, do not quantize them.
+```

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py

+6-9
@@ -272,15 +272,12 @@ def get_user_model():
 def run_fn_for_gptq(model, dataloader_for_calibration, *args):
     for batch in tqdm(dataloader_for_calibration):
         batch = move_input_to_device(batch, device=None)
-        try:
-            if isinstance(batch, tuple) or isinstance(batch, list):
-                model(batch[0])
-            elif isinstance(batch, dict):
-                model(**batch)
-            else:
-                model(batch)
-        except ValueError:
-            pass
+        if isinstance(batch, tuple) or isinstance(batch, list):
+            model(batch[0])
+        elif isinstance(batch, dict):
+            model(**batch)
+        else:
+            model(batch)
     return
 if args.double_quant_type is not None:
     double_quant_config_dict.update(
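The `try`/`except ValueError` around the calibration forward is dropped here because, as the gptq.py change below shows, the quantizer now wraps `model.forward` itself and swallows the `ValueError` it raises to stop execution once inputs are captured.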

neural_compressor/common/base_config.py

+19-4
@@ -198,10 +198,25 @@ def local_config(self):
     def local_config(self, config):
         self._local_config = config
 
-    def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
-        if operator_name in self.local_config:
-            logger.warning("The configuration for %s has already been set, update it.", operator_name)
-        self.local_config[operator_name] = config
+    def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
+        """Set custom configuration based on the global configuration object.
+
+        Args:
+            operator_name_or_list (Union[List, str, Callable]): specific operator
+            config (BaseConfig): specific configuration
+
+        Returns:
+            Updated Config
+        """
+        if isinstance(operator_name_or_list, list):
+            for operator_name in operator_name_or_list:
+                if operator_name in self.local_config:
+                    logger.warning("The configuration for %s has already been set, update it.", operator_name)
+                self.local_config[operator_name] = config
+        else:
+            if operator_name_or_list in self.local_config:
+                logger.warning("The configuration for %s has already been set, update it.", operator_name)
+            self.local_config[operator_name_or_list] = config
         return self
 
     def to_dict(self):
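A brief sketch of what the extended signature enables: one `set_local` call can now fan a single override out to a list of operators (the layer names below are illustrative):

```python
from neural_compressor.torch.quantization import RTNConfig

quant_config = RTNConfig()  # global default (bits=4)
# Previously set_local accepted a single operator name or type per call;
# with the list branch above, several operators can share one override.
quant_config.set_local(["lm_head", "embed_out"], RTNConfig(dtype="fp32"))
```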

neural_compressor/torch/algorithms/weight_only/gptq.py

+13
@@ -345,6 +345,18 @@ def forward(layer, *args, **kwargs):
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )
+        # Step 3: replace model_forward to avoid ValueError
+        self.orig_model_forward_cache = self.model.forward
+        model_forward_cache = self.model.forward
+
+        def model_forward(model, *args, **kwargs):
+            nonlocal model_forward_cache
+            try:
+                model_forward_cache(*args, **kwargs)
+            except ValueError:
+                pass
+
+        self.model.forward = partial(model_forward, self.model)
 
     @torch.no_grad()
     def remove_prepare_for_calibration(self):
@@ -359,6 +371,7 @@ def remove_prepare_for_calibration(self):
         logger.info("Done.")
 
         # Step 4: restore original forward function, relocate layers back to cpu.
+        self.model.forward = self.orig_model_forward_cache
         self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise: # pragma: no cover
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
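The mechanics of Step 3 in isolation, as a standalone sketch with generic names rather than INC internals: the per-block forward hook raises `ValueError` once it has captured its inputs, and the model-level wrapper swallows that exception so a plain calibration loop can call `model(x)` without its own `try`/`except`.

```python
from functools import partial

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.block(x)


model = TinyModel()

# A hook on the first block raises ValueError to stop the forward pass once
# its inputs are captured, mirroring the GPTQ prepare step.
def capturing_forward(layer, *args, **kwargs):
    # ... record *args / **kwargs here for later quantization ...
    raise ValueError("inputs captured, stop forward")

model.block.forward = partial(capturing_forward, model.block)

# The model-level wrapper swallows that ValueError so callers run plain model(x).
orig_forward = model.forward

def model_forward(mod, *args, **kwargs):
    try:
        orig_forward(*args, **kwargs)
    except ValueError:
        pass

model.forward = partial(model_forward, model)
model(torch.randn(1, 4))  # no exception leaks to the calibration loop
# afterwards the original forwards are restored, as in Step 4 above
```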

neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py

+2-4
@@ -119,7 +119,8 @@ def save(self, model, path):
         pass
 
     def _convert_hqq_module_config(self, config) -> HQQModuleConfig:
-        # * 3.x API use `bits` for woq while HQQ internal API use `nbits`
+        # TODO: (Yi) Please note that the configuration defined by INC should be separated from the algorithm.
+        # * 3.x API use `bits` for woq while HQQ internal API use `nbits`, we should change it in algorithm_entry.py
         nbits = config.bits
         group_size = config.group_size
         quant_zero = config.quant_zero
@@ -146,9 +147,6 @@ def _convert_hqq_module_config(self, config) -> HQQModuleConfig:
     def _parse_hqq_configs_mapping(self, configs_mapping):
         qconfig_mapping = {}
         for (op_name, op_type), quant_config in configs_mapping.items():
-            if quant_config.skip_lm_head and "lm_head" in op_name:
-                logger.warning("Skip quantizing %s due to `skip_lm_head` is True.", op_name)
-                continue
             if quant_config is not None and quant_config.dtype == "fp32":
                 logger.warning("Fallback %s.", op_name)
                 continue
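With the HQQ-specific `skip_lm_head` option removed, HQQ follows the same common argument as the other algorithms; a brief sketch, assuming `HQQConfig` now accepts `quant_lm_head` as the commit title and the updated docs indicate:

```python
from neural_compressor.torch.quantization import HQQConfig

# Quantizing lm_head is now opt-in via the shared quant_lm_head flag (default False),
# replacing the former skip_lm_head option.
quant_config = HQQConfig(quant_lm_head=True)
```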

neural_compressor/torch/algorithms/weight_only/rtn.py

+19-2
@@ -19,12 +19,20 @@
 # limitations under the License.
 
 
+import copy
 from collections import OrderedDict
 
 import torch
 
 from neural_compressor.torch.algorithms import Quantizer
-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
+from neural_compressor.torch.utils import (
+    get_accelerator,
+    get_attr,
+    is_transformers_imported,
+    logger,
+    set_attr,
+    set_module,
+)
 
 from .utility import cast_fp8, quant_tensor, search_clip
 
@@ -64,6 +72,7 @@ def convert(
         quantile=1.0,
         use_full_range=False,
         use_mse_search=False,
+        quant_lm_head=False,
         *args,
         **kwargs,
     ):
@@ -80,8 +89,10 @@ def convert(
             quantile (float, optional): percentile of clip. Defaults to 1.0.
             use_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
                 Defaults to False.
-            use_mse_search (bool, optional): Whether search clip range.
+            use_mse_search (bool, optional): Whether to search clip range.
                 Defaults to True.
+            quant_lm_head (bool, optional): Whether to quantize the lm_head layer.
+                Defaults to False.
 
         Returns:
             model: fake quantized torch module
@@ -93,6 +104,12 @@ def convert(
         # TODO: refine it later, Put module on device one by one instead of the whole model
         model.to(device)
 
+        # for transformers model. If lm_head is tied from embedding, we deepcopy it.
+        if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False):
+            for key in model._tied_weights_keys:
+                weight = get_attr(model, key)
+                set_attr(model, key, copy.deepcopy(weight))
+
         assert isinstance(model, torch.nn.Module), "only support torch module"
         if is_transformers_imported():
             supported_layers = (torch.nn.Linear, transformers.Conv1D)
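The untying step above matters because, when `tie_word_embeddings=True`, `lm_head.weight` and the input embedding share one tensor, so quantizing the head in place would also corrupt the embedding. A self-contained sketch of the idea, with simplified dotted-path helpers standing in for INC's `get_attr`/`set_attr` utilities:

```python
import copy
from functools import reduce

import torch


def get_attr(obj, path):
    """Resolve a dotted attribute path like "lm_head.weight" (simplified stand-in)."""
    return reduce(getattr, path.split("."), obj)


def set_attr(obj, path, value):
    """Assign to a dotted attribute path (simplified stand-in)."""
    parent_path, _, name = path.rpartition(".")
    parent = reduce(getattr, parent_path.split("."), obj) if parent_path else obj
    setattr(parent, name, value)


class ToyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 8)
        self.lm_head = torch.nn.Linear(8, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tied, as with tie_word_embeddings=True
        self._tied_weights_keys = ["lm_head.weight"]


model = ToyLM()
# Deep-copy each tied weight so quantizing lm_head leaves the embedding untouched.
for key in model._tied_weights_keys:
    set_attr(model, key, copy.deepcopy(get_attr(model, key)))

assert model.lm_head.weight is not model.embed.weight
```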

neural_compressor/torch/quantization/algorithm_entry.py

+1-1
@@ -92,7 +92,7 @@ def rtn_entry(
     }
 
     quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config)
-    model = quantizer.execute(model, mode=mode)
+    model = quantizer.execute(model, mode=mode, quant_lm_head=quant_config.quant_lm_head)
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
