docs/3x/PT_WeightOnlyQuant.md (+26 -6)
@@ -1,6 +1,7 @@
PyTorch Weight Only Quantization
===============

- [Introduction](#introduction)
- [Supported Matrix](#supported-matrix)
- [Usage](#usage)
@@ -28,7 +29,6 @@ Besides, as mentioned in many papers[1][2], activation quantization is the main
Theoretically, round-to-nearest (RTN) is the most straightforward way to quantize weights using scale maps. However, when the number of bits is small (e.g., 3), the MSE loss is larger than expected. A group size is introduced to reduce the number of elements sharing the same scale, which improves accuracy.

## Supported Matrix

| Algorithms/Backend | PyTorch eager mode |
@@ -58,25 +58,29 @@ Theoretically, round-to-nearest (RTN) is the most straightforward way to quantiz
WeightOnlyQuant for PyTorch uses the prepare and convert [APIs](./PyTorch.md#quantization-apis).

- *group_size = -1* refers to **per output channel quantization**. Taking a linear layer (input channel = $C_{in}$, output channel = $C_{out}$) as an example, when *group_size = -1*, quantization computes $C_{out}$ quantization parameters in total. Otherwise, when *group_size = gs*, quantization parameters are computed for every $gs$ elements along the input channel, leading to $C_{out} \times (C_{in} / gs)$ quantization parameters in total (see the sketch after these notes).
- 4-bit NormalFloat (NF4) is proposed in QLoRA[7]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb.
- Only RTN and GPTQ support double quant.
- *quant_lm_head* defaults to False, which means that, apart from the transformer blocks, the last layer of transformer models is not quantized by default. This last layer may be named "lm_head", "output_layer", or "embed_out".
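As a quick illustration of the *group_size* note above, the following sketch (a hypothetical helper, not part of the library) counts the quantization parameters for a linear weight of shape $C_{out} \times C_{in}$:

```python
def num_quant_params(c_out: int, c_in: int, group_size: int) -> int:
    """Count scale parameters for a [c_out, c_in] linear weight (illustrative only)."""
    if group_size == -1:  # per output channel: one scale per output row
        return c_out
    return c_out * (c_in // group_size)  # one scale per group along the input channel


# Example: a 4096 x 4096 linear layer
print(num_quant_params(4096, 4096, -1))   # 4096
print(num_quant_params(4096, 4096, 128))  # 131072
```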
| static_groups (bool) | Whether to calculate group-wise quantization parameters in advance. This option mitigates actorder's extra computational requirements. | False |
> **Note:** `model_path` is only used when `use_layer_wise=True`. Layer-wise quantization is still under development; stay tuned.

```python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, GPTQConfig
```
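The example above is truncated by the diff. As a hedged sketch of the GPTQ flow (the `user_model` and `run_fn` names are placeholders, and relying on the default `GPTQConfig()` arguments is an assumption), it follows the same prepare/calibrate/convert pattern shown later in this document:

```python
# Hedged sketch only; `user_model` and `run_fn` are user-provided placeholders.
from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

quant_config = GPTQConfig()  # assumed default weight-only GPTQ settings
model = prepare(user_model, quant_config)
run_fn(model)  # GPTQ needs a calibration pass over representative data
model = convert(model)
```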
| quant_zero (bool) | Whether to quantize zero point | True |
| quant_scale (bool) | Whether to quantize scale | False |
| scale_quant_group_size (int) | The group size for quantizing scale | 128 |
| skip_lm_head (bool) | Whether to skip quantizing lm_head | True |

```python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, HQQConfig
```
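Similarly, a hedged sketch of the HQQ flow (the `user_model` name is a placeholder; whether a calibration step is needed is not shown in this diff, and the data-free assumption below is mine):

```python
# Hedged sketch only; `user_model` is a user-provided placeholder.
from neural_compressor.torch.quantization import HQQConfig, convert, prepare

quant_config = HQQConfig()  # assumed default HQQ settings
model = prepare(user_model, quant_config)
model = convert(model)  # assuming HQQ is data-free, no calibration run is shown
```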
@@ -205,10 +219,13 @@ model = prepare(model, quant_config)
```python
run_fn(model)  # calibration
model = convert(model)
```
### Specify Quantization Rules
Intel(R) Neural Compressor supports specifying quantization rules by operator name or operator type. Users can set `local` in a dict or use the `set_local` method of a config class to achieve this.

1. Example of setting `local` from a dict
```python
quant_config = {
    "rtn": {
@@ -226,15 +243,19 @@ quant_config = {
    }
}
```
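The middle of the dict above is collapsed in the diff. An illustrative (not verbatim) version of such a config, with a `local` override that keeps `lm_head` in fp32 to mirror the `set_local` example below, might look like:

```python
# Illustrative sketch only; the key names mirror RTNConfig arguments used elsewhere
# in this document and are not copied from the collapsed diff content.
quant_config = {
    "rtn": {
        "global": {
            "dtype": "int",
            "bits": 4,
            "group_size": 32,
        },
        "local": {
            "lm_head": {
                "dtype": "fp32",
            },
        },
    }
}
```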
2. Example of using `set_local`
```python
quant_config = RTNConfig()
lm_head_config = RTNConfig(dtype="fp32")
quant_config.set_local("lm_head", lm_head_config)
```
### Saving and Loading
The saved_results folder contains two files: quantized_model.pt and qconfig.json, and the generated model is a quantized model that contains WeightOnlyLinear modules. To support low-memory inference, Intel(R) Neural Compressor implemented WeightOnlyLinear, a torch.nn.Module, to compress the fake-quantized fp32 model. Since torch does not provide flexible storage for low-bit data, WeightOnlyLinear packs the low-bit data (weights and zero points) into a wider data type such as torch.int8 or torch.int32. When WeightOnlyLinear is used for inference, it restores the compressed data to float32 and runs the torch linear function.
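As a simplified illustration of the packing idea (not the actual WeightOnlyLinear implementation, which may use torch.int8 or torch.int32 containers), two unsigned 4-bit values can be packed into one byte and unpacked again before the float32 linear computation:

```python
import torch


def pack_uint4(q: torch.Tensor) -> torch.Tensor:
    """Pack unsigned 4-bit values (0..15) two per byte. Illustrative only."""
    q = q.to(torch.uint8).reshape(-1, 2)
    return q[:, 0] | (q[:, 1] << 4)


def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    """Restore the packed nibbles to one value per element."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=1).reshape(-1)


q = torch.randint(0, 16, (8,), dtype=torch.uint8)  # fake 4-bit weight codes
assert torch.equal(unpack_uint4(pack_uint4(q)), q)
```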
```python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, RTNConfig
@@ -255,7 +276,6 @@ loaded_model = load(
)  # Please note that the original_model parameter should receive the original (unquantized) model.
```
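The body of this example is collapsed in the diff. The following is a hedged sketch of the save/load round trip; the `user_model` name is a placeholder, the `load` call is inferred from the comment above, and `model.save("saved_results")` is an assumption about the API rather than text taken from this diff:

```python
# Hedged sketch only; names and the exact save/load signatures are assumptions.
from neural_compressor.torch.quantization import RTNConfig, convert, load, prepare

quant_config = RTNConfig()
model = prepare(user_model, quant_config)
model = convert(model)

# save: expected to produce saved_results/quantized_model.pt and qconfig.json
model.save("saved_results")

# load: the original_model parameter receives the original (unquantized) model
loaded_model = load("saved_results", original_model=user_model)
```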
## Examples
Users can also refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only) on how to quantize a model with WeightOnlyQuant.
@@ -272,6 +292,6 @@ Users can also refer to [examples](https://github.com/intel/neural-compressor/bl
[5]. Cheng, Wenhua, et al. "Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs." arXiv preprint arXiv:2309.05516 (2023).
[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: <https://mobiusml.github.io/hqq_blog/> (2023).
[7]. Dettmers, Tim, et al. "Qlora: Efficient finetuning of quantized llms." arXiv preprint arXiv:2305.14314 (2023).
examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py