Skip to content

Commit 0dc592f

Browse files
authored
[Torch] AWQ support for WeightCompression (#3279)
### Changes Add support for AWQ (`AWQAlgoAlgoBackend`) to the Torch backend: - Move the common AWQ tests into the test template - Add a Torch test case for the WC conformance suite ### Reason for changes To achieve better compression results with AWQ support for Torch ### Related tickets 160668 ### Tests WC run - https://github.com/openvinotoolkit/nncf/actions/runs/13368073134
1 parent 3823804 commit 0dc592f

File tree

12 files changed

+401
-160
lines changed

12 files changed

+401
-160
lines changed

nncf/quantization/algorithms/weight_compression/algorithm.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,16 @@ def __init__(
257257
criterion_cls = MIXED_PRECISION_CRITERIA.get(self._sensitivity_metric)
258258
self._mixed_precision_algo = criterion_cls(primary_config, self._ratio, self._subset_size)
259259
self._statistics_path = self._advanced_parameters.statistics_path
260+
261+
if self._awq:
262+
awq_params = self._advanced_parameters.awq_params
263+
self.awq_algo = AWQ(
264+
awq_params.subset_size,
265+
awq_params.percent_to_apply,
266+
awq_params.alpha_min,
267+
awq_params.alpha_max,
268+
awq_params.steps,
269+
)
260270
if self._gptq:
261271
gptq_params = self._advanced_parameters.gptq_params
262272
self._gptq_algo = GPTQ(
@@ -586,26 +596,12 @@ def apply(
586596
nodes_to_compress = list(
587597
filter(lambda node: node.node_name not in nodes_names_to_exclude, nodes_to_compress)
588598
)
589-
590599
if self._awq:
591-
awq_params = self._advanced_parameters.awq_params
592-
awq_algo = AWQ(
593-
model,
594-
self._backend_entity.name_to_node_mapping,
595-
all_weight_params,
596-
nodes_to_compress,
597-
statistics,
598-
awq_params.subset_size,
599-
awq_params.percent_to_apply,
600-
awq_params.alpha_min,
601-
awq_params.alpha_max,
602-
awq_params.steps,
603-
)
604-
awq_algo.apply(model, graph)
600+
self.awq_algo.apply(model, graph, all_weight_params, nodes_to_compress, statistics, self._backend_entity)
605601
# After applying AWQ we need to update statistics since AWQ alters the activations
606-
statistics = awq_algo.update_statistics(statistics)
602+
statistics = self.awq_algo.update_statistics(statistics)
607603
# del is used to prematurely mark non-necessary data as free for garbage collection
608-
del awq_algo
604+
del self.awq_algo
609605

610606
scales = {}
611607
zero_points = {}

nncf/quantization/algorithms/weight_compression/awq.py

+32-36
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111

1212
from copy import deepcopy
1313
from dataclasses import dataclass
14-
from typing import Any, Dict, List, Optional, TypeVar
14+
from typing import Dict, List, Optional, TypeVar
1515

1616
import nncf
17-
from nncf import Dataset
1817
from nncf import nncf_logger
1918
from nncf.common.factory import ModelTransformerFactory
2019
from nncf.common.graph.graph import NNCFGraph
@@ -29,6 +28,7 @@
2928
from nncf.parameters import CompressWeightsMode
3029
from nncf.quantization.algorithms.algorithm import Algorithm
3130
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
31+
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
3232
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
3333
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
3434
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
@@ -61,34 +61,20 @@ class AWQ(Algorithm):
6161

6262
def __init__(
6363
self,
64-
model: TModel,
65-
name_to_node_mapping: Dict[str, Any],
66-
all_weight_params: List[WeightCompressionParameters],
67-
nodes_to_compress: List[NNCFNode],
68-
statistics: Dict[str, WCTensorStatistic],
6964
subset_size: int = 32,
70-
percent_to_apply=0.002,
71-
alpha_min=0.0,
72-
alpha_max=1.0,
73-
steps=100,
65+
percent_to_apply: float = 0.002,
66+
alpha_min: float = 0.0,
67+
alpha_max: float = 1.0,
68+
steps: int = 100,
7469
):
7570
"""
76-
:param model: Model for applying algorithm.
77-
:param name_to_node_mapping: Name to node mapping for updating node weights.
78-
:param all_weight_params: List of all weight parameters.
79-
:param nodes_to_compress: List of nodes for processing.
80-
:param statistics: Input activation statistics for each node.
8171
:param subset_size: The number of samples for AWQ.
8272
:param percent_to_apply: The percent of outliers for correction.
8373
:param alpha_min: Minimum value of smoothness parameter for grid search.
8474
:param alpha_max: Maximal value of smoothness parameter for grid search.
8575
:param steps: The number of the steps in grid search.
8676
"""
8777
super().__init__()
88-
self.name_to_node_mapping = name_to_node_mapping
89-
self._all_weight_params = all_weight_params
90-
self._nodes_to_compress = nodes_to_compress
91-
self._statistics = statistics
9278
self._subset_size = subset_size
9379
self._percent_to_apply = percent_to_apply
9480
self._alpha_min = alpha_min
@@ -98,44 +84,54 @@ def __init__(
9884
self._patterns = None
9985
self._scale_per_target_node = {}
10086

101-
self._set_backend_entity(model)
102-
10387
@property
10488
def available_backends(self) -> List[BackendType]:
105-
return [BackendType.OPENVINO]
89+
return [BackendType.OPENVINO, BackendType.TORCH]
10690

107-
def _set_backend_entity(self, model: TModel) -> None:
91+
def _set_backend_entity(
92+
self, model: TModel, wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None
93+
) -> None:
10894
"""
10995
Creates a helper class with a backed-specific logic of the algorithm.
11096
11197
:param model: Backend-specific input model.
98+
:param wc_backend_entity: Weight compression algorithm backend.
11299
"""
113100
model_backend = get_backend(model)
114101
if model_backend == BackendType.OPENVINO:
115102
from nncf.quantization.algorithms.weight_compression.openvino_backend import OVAWQAlgoAlgoBackend
116103

117-
self._backend_entity = OVAWQAlgoAlgoBackend(model, self.name_to_node_mapping)
118-
self._patterns = self._backend_entity.get_awq_patterns()
104+
self._backend_entity = OVAWQAlgoAlgoBackend(model, wc_backend_entity.name_to_node_mapping)
105+
elif model_backend == BackendType.TORCH:
106+
from nncf.quantization.algorithms.weight_compression.torch_backend import PTAWQAlgoAlgoBackend
107+
108+
self._backend_entity = PTAWQAlgoAlgoBackend()
109+
119110
else:
120111
msg = f"Cannot return backend-specific AWQ entity because {model_backend.value} is not supported!"
121112
raise nncf.UnsupportedBackendError(msg)
113+
self._patterns = self._backend_entity.get_awq_patterns()
122114

123115
def apply(
124116
self,
125117
model: TModel,
126118
graph: NNCFGraph,
127-
statistic_points: Optional[StatisticPointsContainer] = None,
128-
dataset: Optional[Dataset] = None,
119+
all_weight_params: List[WeightCompressionParameters],
120+
nodes_to_compress: List[NNCFNode],
121+
statistics: Dict[str, WCTensorStatistic],
122+
wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None,
129123
) -> TModel:
130124
"""
131125
Applies the algorithm to the model.
132-
133126
:param model: Model for applying algorithm.
134127
:param graph: Model graph.
135-
:param statistic_points: Statistic points with collected statistics values.
136-
:param dataset: A representative dataset for the calibration process.
128+
:param all_weight_params: List of all weight parameters.
129+
:param nodes_to_compress: List of nodes for processing.
130+
:param statistics: Input activation statistics for each node.
131+
:param wc_backend_entity: Weight compression algorithm backend.
137132
:return: A resulting model.
138133
"""
134+
self._set_backend_entity(model, wc_backend_entity)
139135
matches = []
140136

141137
inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], [])
@@ -151,7 +147,7 @@ def apply(
151147
model_transformer = ModelTransformerFactory.create(model, inplace=True)
152148

153149
awq_data = {}
154-
name_mapping = {wp.weight_name: idx for idx, wp in enumerate(self._all_weight_params)}
150+
name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)}
155151

156152
for match in matches:
157153
nncf_node = graph.get_node_by_key(match[-1])
@@ -166,11 +162,11 @@ def apply(
166162
if target_node_names[-1] not in name_mapping:
167163
continue
168164

169-
weight_params = self._all_weight_params[name_mapping[target_node_names[-1]]]
165+
weight_params = all_weight_params[name_mapping[target_node_names[-1]]]
170166

171167
if weight_params.compression_config.num_bits != 4:
172168
continue
173-
target_node = self._nodes_to_compress[name_mapping[target_node_names[-1]]]
169+
target_node = nodes_to_compress[name_mapping[target_node_names[-1]]]
174170

175171
# avoid matching different patterns for the same node
176172
if target_node.node_name in awq_data:
@@ -182,7 +178,7 @@ def apply(
182178
merge_node_names = []
183179
for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph):
184180
merge_node_names.append(weight_op_friendly_name)
185-
merge_node = self._nodes_to_compress[name_mapping[merge_node_names[-1]]]
181+
merge_node = nodes_to_compress[name_mapping[merge_node_names[-1]]]
186182
else: # pattern Act->MatMul or Act->Multiply->MatMul
187183
merge_node = nncf_node
188184

@@ -204,7 +200,7 @@ def apply(
204200

205201
config = wp.compression_config
206202

207-
s, X = process_stats(self._statistics[k], self._subset_size)
203+
s, X = process_stats(statistics[k], self._subset_size)
208204

209205
top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
210206
topk_idxs = fns.argsort(-s)[:top_k]

nncf/quantization/algorithms/weight_compression/torch_backend.py

+42-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from nncf.common.graph.graph import NNCFNode
2020
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
2121
from nncf.common.graph.operator_metatypes import OperatorMetatype
22+
from nncf.common.graph.patterns import GraphPattern
2223
from nncf.common.graph.transformations.commands import TargetType
2324
from nncf.common.graph.transformations.layout import TransformationLayout
2425
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
@@ -35,6 +36,9 @@
3536
from nncf.experimental.common.tensor_statistics.statistics import MeanVarianceTensorStatistic
3637
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
3738
from nncf.parameters import CompressWeightsMode
39+
from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
40+
from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
41+
from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
3842
from nncf.quantization.algorithms.weight_compression.backend import MixedPrecisionAlgoBackend
3943
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
4044
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
@@ -44,6 +48,8 @@
4448
from nncf.tensor.definitions import TensorDataType
4549
from nncf.torch.dynamic_graph.scope import Scope
4650
from nncf.torch.graph import operator_metatypes as om
51+
from nncf.torch.graph.operator_metatypes import PTMulMetatype
52+
from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
4753
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
4854
from nncf.torch.graph.transformations.commands import PTTargetPoint
4955
from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph
@@ -52,6 +58,7 @@
5258
from nncf.torch.model_graph_manager import get_module_by_name
5359
from nncf.torch.model_graph_manager import split_const_name
5460
from nncf.torch.model_transformer import PTModelTransformer
61+
from nncf.torch.model_transformer import update_parameter
5562
from nncf.torch.nncf_network import NNCFNetwork
5663
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
5764
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
@@ -202,12 +209,12 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
202209
def set_weight(
203210
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor
204211
):
205-
pass
212+
update_parameter(node_with_weight.node_name, "weight", weight.data, model)
206213

207214
def insert_adapters(
208215
self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
209216
) -> None:
210-
pass
217+
raise NotImplementedError()
211218

212219
@staticmethod
213220
def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
@@ -320,6 +327,39 @@ def transform_model(
320327
return transformed_model
321328

322329

330+
class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):
331+
@staticmethod
332+
def get_awq_patterns():
333+
return get_awq_patterns(
334+
PTWeightCompressionAlgoBackend.MATMUL_METATYPES,
335+
PTMulMetatype,
336+
ATOMIC_ACTIVATIONS_OPERATIONS[GraphPattern.METATYPE_ATTR],
337+
)
338+
339+
@staticmethod
340+
def scale_insertion_command(
341+
source_node: NNCFNode,
342+
next_nodes,
343+
source_output_port_id: int,
344+
scale: torch.Tensor,
345+
) -> PTSharedFnInsertionCommand:
346+
input_port_id = 0
347+
target_points = []
348+
for node in next_nodes:
349+
target_points.append(
350+
PTTargetPoint(
351+
PTWeightCompressionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[TargetType.PRE_LAYER_OPERATION],
352+
node.node_name,
353+
input_port_id=input_port_id,
354+
)
355+
)
356+
357+
sq_multiply = SQMultiply(scale.shape)
358+
sq_multiply.scale = scale
359+
scale_node_name = f"{source_node.node_name}/awq_mul"
360+
return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
361+
362+
323363
class PTMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, PTWeightCompressionAlgoBackend):
324364
@staticmethod
325365
def mean_variance_statistic_collector(

nncf/quantization/quantize_model.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,7 @@ def compress_weights(
518518
msg = "Torch backend does not support NF4 and E2M1 modes for weight compression."
519519
raise nncf.ParameterNotSupportedError(msg)
520520

521-
options = {
522-
"awq": awq,
523-
"gptq": gptq,
524-
"lora_correction": lora_correction,
525-
}
521+
options = {"gptq": gptq, "lora_correction": lora_correction}
526522
unsupported_options = [name for name, value in options.items() if value is not None]
527523
if unsupported_options:
528524
msg = f"Torch backend does not support {', '.join(unsupported_options)} option(s). Set them to None."

nncf/torch/model_transformer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from typing import Callable, Dict, List, Optional, Tuple
1515

1616
import torch
17-
from torch import Tensor
1817
from torch import nn
1918
from torch.nn.parameter import Parameter
2019

@@ -242,7 +241,7 @@ def _apply_weights_update_transformations(
242241
return model
243242

244243

245-
def update_parameter(target_node_name: str, parameter_name: str, new_value: Tensor, model: NNCFNetwork) -> None:
244+
def update_parameter(target_node_name: str, parameter_name: str, new_value: torch.Tensor, model: NNCFNetwork) -> None:
246245
"""
247246
Update parameter for target module.
248247

0 commit comments

Comments
 (0)