
Commit 5f4378e

[QAT Lora 3/N] Introduced FQ with absorbable LoRA as a new weight compression format (#3322)
### Changes

As stated in the title.

### Reason for changes

Part of the QAT method with absorbable LoRA adapters to achieve better accuracy for 4-bit models.

![image](https://github.com/user-attachments/assets/b3019e9a-b67a-4aec-9781-e6f94905938b)

Tuning pipeline:

![image](https://github.com/user-attachments/assets/50738363-8c5a-492a-915d-1ee88982599b)

### Related tickets

154907

### Tests

- [x] tests/torch/ptq/test_fq_lora.py

microsoft/phi3.5-mini-instruct, seqlen=4096

| Method | Main Precision | Emb/Head Precision | Group Size | wikitext, word_ppl |
|--------|----------------|--------------------|------------|--------------------|
| Original model (OV) | BF16 | BF16 | | 9.98 |
| Original model (Torch) | BF16 | BF16 | | 10.00 |
| [QAT] Mergeable LoRA | INT4_ASYM | INT8_ASYM | 64 | 10.47 |
| [PTQ] AWQ + Scale Estimation + GPTQ | INT4_ASYM | INT8_ASYM | 64 | 10.71 |
| [QAT] Mergeable LoRA | INT4_SYM | INT8_SYM | 512 | 10.86 |
| [PTQ] AWQ + Scale Estimation + GPTQ | INT4_SYM | INT8_SYM | 512 | 11.32 |

HuggingFaceTB/SmolLM-1.7B-Instruct, seqlen=2048

| Method | Main Precision | Emb/Head Precision | Group Size | wikitext, word_ppl |
|--------|----------------|--------------------|------------|--------------------|
| Original model | BF16 | BF16 | | 19.11 |
| [QAT] Mergeable LoRA | INT4_ASYM | INT8_ASYM | 64 | 19.31 |
| [PTQ] AWQ + Scale Estimation + GPTQ | INT4_ASYM | INT8_ASYM | 64 | 19.68 |
1 parent 90f3727 commit 5f4378e

27 files changed (+809 / −155 lines)
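Before the per-file diffs, a minimal sketch of the intended workflow, assuming the public `nncf.compress_weights` exposes the new `compression_format` argument the way the backend signatures below do. `model`, `calibration_dataset`, and `train_loader` are assumed to exist; the fine-tuning loop is illustrative and not part of this commit.

```python
import nncf
import torch

# Compress weights with the new FQ_LORA format (a sketch, not the exact API).
compressed = nncf.compress_weights(
    model,  # assumed: a torch.nn.Module LLM
    mode=nncf.CompressWeightsMode.INT4_ASYM,
    group_size=64,
    dataset=calibration_dataset,  # assumed: an nncf.Dataset
    compression_format=nncf.CompressionFormat.FQ_LORA,
)

# Weights keep their original precision: FakeQuantize ops plus absorbable LoRA
# adapters are inserted, so the model can be tuned with a plain PyTorch loop.
optimizer = torch.optim.AdamW(
    (p for p in compressed.parameters() if p.requires_grad), lr=1e-4
)
for batch in train_loader:  # illustrative HF-style batches
    loss = compressed(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```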

`.ci/cspell_dict.txt` (+1)

```diff
@@ -220,6 +220,7 @@ logit
 loglikelihoods
 lstmsequence
 lstsq
+lspec
 lyalyushkin
 mapillary
 maskrcnn
```

`nncf/__init__.py` (+1)

```diff
@@ -34,6 +34,7 @@
 from nncf.errors import UnsupportedVersionError as UnsupportedVersionError
 from nncf.errors import ValidationError as ValidationError
 from nncf.parameters import BackupMode as BackupMode
+from nncf.parameters import CompressionFormat as CompressionFormat
 from nncf.parameters import CompressWeightsMode as CompressWeightsMode
 from nncf.parameters import DropType as DropType
 from nncf.parameters import ModelType as ModelType
```

`nncf/common/quantization/structs.py` (+12 / −4)

```diff
@@ -27,14 +27,22 @@
 @api()
 class QuantizationScheme(StrEnum):
     """
-    Basic enumeration for quantization scheme specification.
-
-    :param SYMMETRIC:
-    :param ASYMMETRIC:
+    Enumeration for specifying quantization schemes.
+
+    :param SYMMETRIC: Symmetric quantization where the range is defined by a single parameter - scale.
+        This range can include both negative and positive values if signed, or only positive values if unsigned.
+    :param ASYMMETRIC: Asymmetric quantization where the range is defined by two parameters - input_low and input_high,
+        representing the lower and upper boundaries of the range, respectively.
+    :param SYMMETRIC_LORA: Symmetric quantization with Low-Rank Adapters (LoRA), involving the sum of weights and
+        the multiplication of low-rank adapters.
+    :param ASYMMETRIC_LORA: Asymmetric quantization with Low-Rank Adapters (LoRA), involving the sum of weights and
+        the multiplication of low-rank adapters.
     """

     SYMMETRIC = "symmetric"
     ASYMMETRIC = "asymmetric"
+    SYMMETRIC_LORA = "symmetric_lora"
+    ASYMMETRIC_LORA = "asymmetric_lora"


 class QuantizerConfig:
```
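A rough sketch of what the new `SYMMETRIC_LORA` scheme computes, per the docstring above: the frozen weight and the low-rank product are summed before fake quantization. The function name and the quantizer formula are illustrative, not NNCF's actual implementation (which also handles per-channel scales and gradient flow through the rounding op).

```python
import torch

def symmetric_lora_fq(weight, lora_a, lora_b, scale, num_bits=4):
    """Fake-quantize the sum of a frozen weight and its LoRA product.

    Shapes: weight [out, in], lora_b [out, r], lora_a [r, in], r << min(out, in).
    B @ A has the same shape as the weight, so after tuning the adapters can be
    absorbed (merged) into the weight at no extra inference cost.
    """
    w = weight + lora_b @ lora_a
    level_low = -(2 ** (num_bits - 1))    # e.g. -8 for 4-bit
    level_high = 2 ** (num_bits - 1) - 1  # e.g. +7 for 4-bit
    q = torch.clamp(torch.round(w / scale), level_low, level_high)
    return q * scale  # quantize-dequantize: values snap to the int grid
```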

`nncf/experimental/torch/fx/quantization/quantize_model.py` (+3)

```diff
@@ -31,6 +31,7 @@
 from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
 from nncf.experimental.torch.fx.transformations import fq_weights_transformation
 from nncf.parameters import BackupMode
+from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import ModelType
 from nncf.parameters import QuantizationMode
@@ -131,6 +132,7 @@ def compress_weights_impl(
     gptq: bool,
     lora_correction: bool,
     backup_mode: BackupMode,
+    compression_format: CompressionFormat,
     advanced_parameters: Optional[AdvancedCompressionParameters] = None,
 ) -> torch.fx.GraphModule:
     """
@@ -149,6 +151,7 @@ def compress_weights_impl(
         gptq,
         lora_correction,
         backup_mode,
+        compression_format,
         advanced_parameters,
     )
     graph = NNCFGraphFactory.create(model)
```

`nncf/openvino/quantization/quantize_model.py` (+3)

```diff
@@ -32,6 +32,7 @@
 from nncf.openvino.quantization.quantize_ifmodel import apply_algorithm_if_bodies
 from nncf.openvino.rt_info import dump_parameters
 from nncf.parameters import BackupMode
+from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import DropType
 from nncf.parameters import ModelType
@@ -376,6 +377,7 @@ def compress_weights_impl(
     gptq: bool,
     lora_correction: bool,
     backup_mode: BackupMode,
+    compression_format: CompressionFormat,
     advanced_parameters: Optional[AdvancedCompressionParameters] = None,
 ) -> ov.Model:
     """
@@ -396,6 +398,7 @@ def compress_weights_impl(
         gptq,
         lora_correction,
         backup_mode,
+        compression_format,
         advanced_parameters,
     )
```

`nncf/parameters.py` (+23)

```diff
@@ -96,6 +96,29 @@ class CompressWeightsMode(StrEnum):
     E2M1 = "e2m1"


+@api(canonical_alias="nncf.CompressionFormat")
+class CompressionFormat(StrEnum):
+    """
+    Describes the format in which the model is saved after weight compression.
+
+    :param DQ: Represents the 'dequantize' format, where weights are stored in low-bit precision,
+        and a dequantization subgraph is added to the model. This is the default format for post-training weight
+        compression methods.
+    :param FQ: Represents the 'fake_quantize' format, where quantization is simulated by applying
+        quantization and dequantization operations. Weights remain in the same precision. This format is
+        suitable for quantization-aware training (QAT).
+    :param FQ_LORA: Represents the 'fake_quantize_with_lora' format, which combines fake quantization
+        with absorbable low-rank adapters (LoRA). Quantization is applied to the sum of weights and
+        the multiplication of adapters. This makes quantization-aware training (QAT) more efficient in terms of
+        accuracy, as adapters can also be tuned and remain computationally affordable during training due to their
+        small dimensions.
+    """
+
+    DQ = "dequantize"
+    FQ = "fake_quantize"
+    FQ_LORA = "fake_quantize_with_lora"
+
+
 @api(canonical_alias="nncf.BackupMode")
 class BackupMode(StrEnum):
     """
```
`nncf/quantization/advanced_parameters.py` (+3)

```diff
@@ -384,6 +384,9 @@ class AdvancedCompressionParameters:
     # Advanced Lora Correction algorithm parameters
     lora_correction_params: AdvancedLoraCorrectionParameters = field(default_factory=AdvancedLoraCorrectionParameters)

+    # rank of lora adapters for FQ_LORA format. Defaults to 256.
+    lora_adapter_rank: int = 256
+

 @api()
 @dataclass
```
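A hedged example of overriding the new default rank; parameter names follow this diff, while `model` and `calibration_dataset` are assumed to exist:

```python
import nncf
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters

# Smaller ranks mean fewer trainable parameters; larger ranks add capacity
# at the cost of more memory during tuning (256 is the default in this diff).
compressed = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_ASYM,
    dataset=calibration_dataset,
    compression_format=nncf.CompressionFormat.FQ_LORA,
    advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=32),
)
```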

`nncf/quantization/algorithms/weight_compression/algorithm.py` (+11)

```diff
@@ -31,6 +31,7 @@
 from nncf.common.utils.helpers import create_table
 from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
 from nncf.parameters import BackupMode
+from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import SensitivityMetric
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
@@ -122,6 +123,7 @@ def check_user_compression_configuration(
     ignored_scope: Optional[IgnoredScope],
     sensitivity_metric: Optional[SensitivityMetric],
     backup_mode: Optional[BackupMode],
+    compression_format: Optional[CompressionFormat],
     advanced_parameters: Optional[AdvancedCompressionParameters],
 ) -> None:
     """
@@ -172,6 +174,10 @@ def check_user_compression_configuration(
             requires a dataset, but it's not provided."
         raise nncf.ValidationError(msg)

+    if lora_correction and compression_format in [CompressionFormat.FQ, CompressionFormat.FQ_LORA]:
+        msg = "LoRA Correction algorithm is not compatible with FQ and FQ_LORA compression formats."
+        raise nncf.ValidationError(msg)
+

 class WeightCompression(Algorithm):
     """
@@ -195,6 +201,7 @@ def __init__(
         gptq: bool,
         lora_correction: bool,
         backup_mode: BackupMode = BackupMode.INT8_ASYM,
+        compression_format: CompressionFormat = CompressionFormat.DQ,
         advanced_parameters: Optional[AdvancedCompressionParameters] = None,
     ):
         """
@@ -233,6 +240,7 @@ def __init__(
             In this mode, weights are retained in their original precision without any quantization.
             INT8_SYM stands for 8-bit integer symmetric quantization without zero point.
             INT8_ASYM stands for 8-bit integer asymmetric quantization with a typical non-fixed zero point.
+        :param compression_format: Describes the format in which the model is saved after weight compression.
         :param advanced_parameters: advanced parameters for algorithms in compression pipeline.
         """
         super().__init__()
@@ -251,6 +259,7 @@ def __init__(
         self._gptq = gptq
         self._lora_correction = lora_correction
         self._backup_mode = backup_mode
+        self._compression_format = compression_format
         self._advanced_parameters = (
             advanced_parameters if advanced_parameters is not None else AdvancedCompressionParameters()
         )
@@ -646,6 +655,7 @@ def apply(
             scales,
             zero_points,
             lora_correction_algo,
+            self._compression_format,
         )

         self._backend_entity.dump_parameters(
@@ -662,6 +672,7 @@ def apply(
                 "gptq": self._gptq,
                 "lora_correction": self._lora_correction,
                 "backup_mode": self._backup_mode.value,
+                "compression_format": self._compression_format.value,
                 "advanced_parameters": convert_to_dict_recursively(self._advanced_parameters),
             },
             algo_name="weight_compression",
```
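Given the new validation check above, a call that combines the LoRA Correction algorithm with a trainable format should fail fast (a sketch; `model` is assumed):

```python
import nncf

try:
    nncf.compress_weights(
        model,
        mode=nncf.CompressWeightsMode.INT4_ASYM,
        lora_correction=True,  # post-hoc correction adapters
        compression_format=nncf.CompressionFormat.FQ_LORA,  # trainable adapters
    )
except nncf.ValidationError:
    # The two LoRA mechanisms are mutually exclusive by design.
    pass
```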

`nncf/quantization/algorithms/weight_compression/backend.py` (+14 / −3)

```diff
@@ -24,7 +24,10 @@
 from nncf.experimental.common.tensor_statistics.collectors import RawReducer
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.common.tensor_statistics.statistics import HessianTensorStatistic
+from nncf.parameters import CompressionFormat
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
+from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm
 from nncf.tensor import Tensor
 from nncf.tensor import TensorDataType

@@ -147,15 +150,23 @@ def transform_model(
         weight_compression_parameters: Iterable[WeightCompressionParameters],
         precomputed_scales: Dict[str, Tensor] = None,
         precomputed_zero_points: Dict[str, Tensor] = None,
+        lora_correction_algo: Optional[LoraCorrectionAlgorithm] = None,
+        compression_format: CompressionFormat = CompressionFormat.DQ,
+        advanced_parameters: AdvancedCompressionParameters = AdvancedCompressionParameters(),
     ) -> TModel:
         """
         Applies weight compression transformations to the model.

         :param model: Model in which the weights will be compressed according to the weight compression description.
         :param graph: The graph associated with the model.
-        :param weight_compression_parameters: List of weight compression parameters.
-        :param precomputed_scales: Precomputed scales for weights compression.
-        :param precomputed_zero_points: Precomputed zero points for weights compression.
+        :param weight_compression_parameters: An iterable of weight compression parameters.
+        :param precomputed_scales: Precomputed scales for weight compression.
+        :param precomputed_zero_points: Precomputed zero points for weight compression.
+        :param lora_correction_algo: An optional algorithm to reduce quantization noise after weight compression by
+            using low-rank adapters. This algorithm not only overrides weights with their quantized counterparts but
+            also expands the model's execution graph following the Low-Rank Adaptation (LoRA) concept.
+        :param compression_format: The format in which the model is saved after weight compression.
+        :param advanced_parameters: Describes advanced parameters of compression formats.
         :return: The transformed model.
         """
```
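For context on the new `lora_correction_algo` parameter: per the docstring, the algorithm reduces quantization noise by attaching low-rank adapters next to the quantized weight. A simplified SVD-based sketch of the idea (the real algorithm also uses activation statistics):

```python
import torch

def lora_correction_sketch(w_fp, w_q, rank=8):
    """Approximate the quantization error (w_fp - w_q) with a rank-`rank`
    product B @ A, to be executed as a parallel branch at inference time."""
    residual = w_fp - w_q
    u, s, vh = torch.linalg.svd(residual, full_matrices=False)
    b = u[:, :rank] * s[:rank]  # [out, rank]
    a = vh[:rank, :]            # [rank, in]
    # Forward pass becomes: x @ w_q.T + (x @ a.T) @ b.T  ~=  x @ w_fp.T
    return a, b
```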

`nncf/quantization/algorithms/weight_compression/openvino_backend.py` (+4)

```diff
@@ -46,7 +46,9 @@
 from nncf.openvino.statistics.collectors import OVMeanReducer
 from nncf.openvino.statistics.collectors import OVMeanVarianceReducer
 from nncf.openvino.statistics.collectors import OVShapeReducer
+from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
 from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
 from nncf.quantization.algorithms.weight_compression.backend import MixedPrecisionAlgoBackend
@@ -283,6 +285,8 @@ def transform_model(
         precomputed_scales: Dict[str, Tensor] = None,
         precomputed_zero_points: Dict[str, Tensor] = None,
         lora_correction_algo: LoraCorrectionAlgorithm = None,
+        compression_format: CompressionFormat = CompressionFormat.DQ,
+        advanced_parameters: AdvancedCompressionParameters = AdvancedCompressionParameters(),
     ) -> ov.Model:
         for wc_params in weight_compression_parameters:
             const_attributes = wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]
```

`nncf/quantization/algorithms/weight_compression/scale_estimation.py` (−1)

```diff
@@ -232,7 +232,6 @@ def calculate_quantization_params(
         X, _ = reshape_weight_for_grouped_quantization(X, 0, group_size)
         best_diffs = None
         result_scale = None
-
         fp_outs = fns.matmul(fns.transpose(original_weight, (1, 0, 2)), X)
         q_outs = fns.matmul(fns.transpose(q_weights, (1, 0, 2)), X)
```
