Skip to content

Commit 30ee587

Browse files
committed
[ONNX]: Add support for data-free Weight Compression Algorithm (#3273)
1 parent 5f4378e commit 30ee587

File tree

9 files changed

+899
-1
lines changed

9 files changed

+899
-1
lines changed

nncf/onnx/graph/layer_attributes.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Any, Dict, List, Optional
13+
14+
from nncf.common.graph.layer_attributes import BaseLayerAttributes
15+
16+
17+
class ONNXLayerAttributes(BaseLayerAttributes):
    """
    Container for the extra per-node information that is required while compressing an ONNX model.
    """

    def __init__(
        self,
        constant_attributes: Dict[int, Any],
        layer_attributes: Optional[BaseLayerAttributes] = None,
        inputs_attributes: Optional[Dict[Any, Any]] = None,
    ):
        """
        :param constant_attributes: Attributes of the constant inputs, keyed by the input port ID.
        :param layer_attributes: Common layer attributes of the node, if there are any.
        :param inputs_attributes: Attributes of the activation inputs, if there are any.
        """
        self._inputs_attributes = inputs_attributes
        self._layer_attributes = layer_attributes
        self._constant_attributes = constant_attributes

    @property
    def constant_attributes(self) -> Dict[int, Any]:
        # Read-only view of the mapping passed to the constructor.
        return self._constant_attributes

    @property
    def layer_attributes(self) -> Optional[BaseLayerAttributes]:
        # Read-only view of the common layer attributes passed to the constructor.
        return self._layer_attributes

    @property
    def input_attributes(self) -> Optional[Dict[Any, Any]]:
        # NOTE: the property is singular while the constructor argument is `inputs_attributes`.
        return self._inputs_attributes

    def get_const_port_ids(self) -> List[int]:
        """
        Collects the IDs of the input ports which are fed by constants.

        :returns: Input port IDs that have constant attributes; an empty list when nothing is stored.
        """
        if self._constant_attributes is None:
            return []
        return list(self._constant_attributes)

nncf/onnx/graph/layout.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from enum import Enum
13+
from typing import Tuple
14+
15+
from nncf.common.graph.graph import NNCFNode
16+
from nncf.onnx.graph.layer_attributes import ONNXLayerAttributes
17+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype
18+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDepthwiseConvolutionMetatype
19+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGroupConvolutionMetatype
20+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
21+
22+
23+
class ONNXLayoutElem(Enum):
    """
    Names of the dimensions which may occur in the layout of convolutional and linear ONNX layers.

    Members:
        C_IN: Input channels dimension.
        C_OUT: Output channels dimension.
        SPATIAL: Spatial dimension.
        GROUPS: Groups dimension.
    """

    C_IN = "channels_in"
    C_OUT = "channels_out"
    SPATIAL = "spatial"
    GROUPS = "groups"
36+
37+
38+
# Leading (non-spatial) dimensions of a convolution weight, per metatype.
# The ONNX `Conv` operator always stores its weight as (M, C / group, k_1, ..., k_n): the groups are
# folded into the output-channel dimension and there is no dedicated groups axis, whatever the value
# of the `group` attribute is. Hence regular, depthwise and group convolutions share one base layout.
_CONV_BASE_CONST_LAYOUT = {
    ONNXConvolutionMetatype: (ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN),
    ONNXDepthwiseConvolutionMetatype: (ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN),
    ONNXGroupConvolutionMetatype: (ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN),
}
43+
44+
45+
def get_conv_weights_layout_from_node(node: NNCFNode) -> Tuple[ONNXLayoutElem]:
    """
    Determines the layout of the weight tensor of the given convolution node.

    :param node: Convolution node to inspect.
    :return: Layout of the weights of the convolution node.
    """
    attrs = node.layer_attributes
    weight_port_id = _get_constant_port_id_from_layer_attributes(attrs)
    metatype = node.metatype
    # The weight shape is read from the attributes stored for the single constant port.
    weights_shape = attrs.constant_attributes[weight_port_id]["shape"]
    return get_conv_weights_layout(ONNX_metatype=metatype, weights_shape=weights_shape)
57+
58+
59+
def get_linear_weights_layout_from_node(node: NNCFNode) -> Tuple[ONNXLayoutElem]:
    """
    Determines the layout of the weight tensor of the given linear node.

    :param node: Linear node to inspect.
    :return: Layout of the weights of the linear node.
    """
    attrs = node.layer_attributes
    weight_port_id = _get_constant_port_id_from_layer_attributes(attrs)
    weight_attrs = attrs.constant_attributes[weight_port_id]
    shape = weight_attrs["shape"]
    is_transposed = weight_attrs["transpose"]
    return get_linear_input_layout(input_shape=shape, transpose=is_transposed, port_id=weight_port_id)
74+
75+
76+
def get_linear_activations_layout_from_node(
    node: NNCFNode, port_id: int, input_shape: Tuple[int]
) -> Tuple[ONNXLayoutElem]:
    """
    Determines the layout of an activation input of the given linear node.

    :param node: Linear node to inspect.
    :param port_id: ID of the input port the activation arrives at.
    :param input_shape: Shape of the activation.
    :return: Layout of the activation on the given port of the linear node.
    """
    # The transpose flag is taken from the activation attributes stored on the node.
    is_transposed = node.layer_attributes.input_attributes["transpose"]
    return get_linear_input_layout(input_shape=input_shape, transpose=is_transposed, port_id=port_id)
93+
94+
95+
def get_conv_weights_layout(ONNX_metatype: ONNXOpMetatype, weights_shape: Tuple[int, ...]) -> Tuple[ONNXLayoutElem]:
    """
    Builds the weights layout of a convolution from its metatype and the shape of its weight.

    :param ONNX_metatype: ONNX metatype of the convolution node.
    :param weights_shape: Shape of the convolution weight.
    :return: Layout of the convolution weight.
    """
    channel_dims = _CONV_BASE_CONST_LAYOUT[ONNX_metatype]
    # Every dimension which follows the base (channel) dimensions is treated as a spatial one.
    num_spatial_dims = len(weights_shape[len(channel_dims) :])
    return tuple(channel_dims) + (ONNXLayoutElem.SPATIAL,) * num_spatial_dims
107+
108+
109+
def get_linear_input_layout(input_shape: Tuple[int, ...], transpose: bool, port_id: int) -> Tuple[ONNXLayoutElem]:
    """
    Builds the layout of one input of a linear node.

    :param input_shape: Shape of the linear node input.
    :param transpose: Whether the input is transposed.
    :param port_id: ID of the port the input arrives at.
    :return: Layout of the linear node input.
    """
    rank = len(input_shape)
    if rank <= 1:
        # An input without two matrix dimensions carries the input channels only.
        return (ONNXLayoutElem.C_IN,)

    # All dimensions in front of the last two are marked as spatial.
    leading_dims = (ONNXLayoutElem.SPATIAL,) * (rank - 2)
    in_channels_first = (port_id == 0) if transpose else (port_id == 1)
    if in_channels_first:
        matrix_dims = (ONNXLayoutElem.C_IN, ONNXLayoutElem.C_OUT)
    else:
        matrix_dims = (ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN)
    return leading_dims + matrix_dims
126+
127+
128+
def _get_constant_port_id_from_layer_attributes(layer_attributes: ONNXLayerAttributes) -> int:
    """
    Extracts the ID of the only constant port stored in the layer attributes of a convolutional/linear op.

    :param layer_attributes: Layer attributes of the convolutional/linear op.
    :return: ID of the constant port of the op.
    """
    const_port_ids = tuple(layer_attributes.constant_attributes)
    # Exactly one constant port is expected here.
    assert len(const_port_ids) == 1
    return const_port_ids[0]

nncf/onnx/graph/metatypes/groups.py

+6
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,9 @@
164164
onnx_metatypes.ONNXROIAlignMetatype,
165165
onnx_metatypes.ONNXEmbeddingMetatype,
166166
]
167+
168+
# Metatypes of the convolution operations: regular, depthwise and group convolutions.
CONV_OPERATIONS = [
    onnx_metatypes.ONNXConvolutionMetatype,
    onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
    onnx_metatypes.ONNXGroupConvolutionMetatype,
]

nncf/onnx/graph/node_utils.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Dict, Optional, Tuple
12+
from typing import Dict, Optional, Tuple, List
1313

1414
import numpy as np
1515
import onnx
@@ -18,10 +18,20 @@
1818
from nncf.common.graph.graph import NNCFNode
1919
from nncf.common.graph.transformations.commands import TargetType
2020
from nncf.common.logging.logger import nncf_logger
21+
from nncf.onnx.graph.layout import ONNXLayoutElem
22+
from nncf.onnx.graph.layout import get_conv_weights_layout
23+
from nncf.onnx.graph.layout import get_conv_weights_layout_from_node
24+
from nncf.onnx.graph.layout import get_linear_activations_layout_from_node
25+
from nncf.onnx.graph.layout import get_linear_input_layout
26+
from nncf.onnx.graph.layout import get_linear_weights_layout_from_node
2127
from nncf.onnx.graph.metatypes import onnx_metatypes as om
2228
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDequantizeLinearMetatype
29+
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
2330
from nncf.onnx.graph.onnx_helper import get_tensor_value
2431
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
32+
from nncf.onnx.graph.metatypes.groups import CONV_OPERATIONS
33+
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
34+
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS
2535

2636

2737
def is_node_with_bias(node: NNCFNode) -> bool:
@@ -139,6 +149,36 @@ def get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
139149
weight_channel_axis = -1 - port_id if transpose else -2 + port_id
140150
return weight_channel_axis
141151

152+
def get_weight_channel_axes(node: NNCFNode) -> List[int]:
    """
    Returns axes numbers of the weight tensor which correspond to its channels.

    :param node: NNCFNode with weights.
    :return: Axes numbers of the weight tensor which correspond to its channels.
    :raises ValueError: If the metatype of the node is not an operation with weights.
    """
    if node.metatype not in OPERATIONS_WITH_WEIGHTS:
        msg = "Channel axis cannot be defined for operation without weights."
        raise ValueError(msg)

    if node.metatype in CONV_OPERATIONS:
        weights_layout = get_conv_weights_layout_from_node(node)
        # Both the groups and the output-channel dimensions are reported as channel axes.
        channel_elems = (ONNXLayoutElem.GROUPS, ONNXLayoutElem.C_OUT)
        return [idx for idx, elem in enumerate(weights_layout) if elem in channel_elems]
    if node.metatype == ONNXMatMulMetatype:
        return get_matmul_channel_axes(node)
    # Any other operation with weights relies on the channel axis declared by its metatype.
    return node.metatype.const_channel_axis
170+
171+
172+
def get_matmul_channel_axes(node: NNCFNode) -> List[int]:
    """
    Calculate channel axes for the MatMul operation.

    :param node: The target MatMul node.
    :return: List of channel axes for the MatMul operation.
    """
    weights_layout = get_linear_weights_layout_from_node(node)
    # Every dimension except the input-channel one (i.e. spatial and output-channel) is a channel axis.
    channel_elems = (ONNXLayoutElem.SPATIAL, ONNXLayoutElem.C_OUT)
    return [idx for idx, elem in enumerate(weights_layout) if elem in channel_elems]
181+
142182

143183
def get_act_quantization_axis(node: NNCFNode, port_id: int) -> int:
144184
"""
@@ -214,3 +254,7 @@ def get_quantized_tensor_shape(
214254
if target_point.is_weight_target_point():
215255
return node.layer_attributes.weight_attrs[target_point.port_id]["shape"]
216256
return _get_activation_tensor_shape(nncf_graph, node, target_point)
257+
258+
259+
def get_const_value_as_onnx_tensor(initializer_name: str, model: onnx.ModelProto) -> np.ndarray:
    """
    Returns the value of the model initializer with the given name.

    :param initializer_name: Name of the initializer to read.
    :param model: ONNX model which holds the initializer.
    :return: Value of the initializer as a NumPy array.
    :raises ValueError: If the model has no initializer with the given name.
    """
    # NOTE(review): the original body was only a TODO; this implementation follows the signature
    # (initializer name in, `np.ndarray` out) - confirm that this is the intended contract.
    # Imported locally so that the helper does not rely on `import onnx` exposing the submodule.
    from onnx import numpy_helper

    for initializer in model.graph.initializer:
        if initializer.name == initializer_name:
            return numpy_helper.to_array(initializer)
    msg = f"Initializer '{initializer_name}' was not found in the model."
    raise ValueError(msg)

nncf/onnx/quantization/quantize_model.py

+71
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,40 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from pathlib import Path
1213
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
1314

1415
import onnx
1516

1617
import nncf
18+
from nncf.common.factory import NNCFGraphFactory
19+
from nncf.common.factory import StatisticsAggregatorFactory
1720
from nncf.common.logging.logger import nncf_logger
1821
from nncf.common.quantization.structs import QuantizationPreset
1922
from nncf.data import Dataset
2023
from nncf.onnx.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
2124
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
25+
from nncf.parameters import BackupMode
26+
from nncf.parameters import CompressionFormat
27+
from nncf.parameters import CompressWeightsMode
2228
from nncf.parameters import DropType
2329
from nncf.parameters import ModelType
2430
from nncf.parameters import QuantizationMode
31+
from nncf.parameters import SensitivityMetric
2532
from nncf.parameters import TargetDevice
2633
from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters
34+
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
2735
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
2836
from nncf.quantization.advanced_parameters import QuantizationParameters
2937
from nncf.quantization.algorithms.accuracy_control.algorithm import QuantizationAccuracyRestorer
3038
from nncf.quantization.algorithms.accuracy_control.algorithm import calculate_accuracy_drop
3139
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator
3240
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
41+
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
3342
from nncf.quantization.quantize_model import quantize_with_tune_hyperparams
3443
from nncf.quantization.quantize_model import warning_model_no_batchwise_support
44+
from nncf.quantization.statistics_caching import cache_weight_compression_statistics
45+
from nncf.quantization.statistics_caching import register_statistics_for_algorithm
3546
from nncf.scopes import IgnoredScope
3647

3748
TTensor = TypeVar("TTensor")
@@ -201,3 +212,63 @@ def quantize_with_accuracy_control_impl(
201212
)
202213

203214
return quantized_model
215+
216+
def compress_weights_impl(
    model: onnx.ModelProto,
    dataset: Dataset,
    mode: CompressWeightsMode,
    ratio: float,
    group_size: int,
    ignored_scope: IgnoredScope,
    all_layers: bool,
    sensitivity_metric: SensitivityMetric,
    awq: bool,
    subset_size: int,
    scale_estimation: bool,
    gptq: bool,
    lora_correction: bool,
    backup_mode: BackupMode,
    compression_format: CompressionFormat,
    advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> onnx.ModelProto:
    """
    Implementation of the `compress_weights()` method for the ONNX backend.
    """
    graph = NNCFGraphFactory.create(model)
    # The algorithm takes its settings positionally, in exactly this order.
    algorithm_args = (
        mode,
        ratio,
        group_size,
        ignored_scope,
        all_layers,
        sensitivity_metric,
        awq,
        subset_size,
        scale_estimation,
        gptq,
        lora_correction,
        backup_mode,
        compression_format,
        advanced_parameters,
    )
    compression_algorithm = WeightCompression(*algorithm_args)

    requested_statistics_path = advanced_parameters.statistics_path if advanced_parameters else None
    if not requested_statistics_path:
        # No statistics directory was requested: the algorithm is applied without preloaded statistics.
        return compression_algorithm.apply(model, graph, None, dataset)

    statistics_dir = Path(requested_statistics_path)
    if not statistics_dir.exists():
        # Nothing has been cached at this path yet, so the statistics are collected and dumped first.
        cache_weight_compression_statistics(model, graph, dataset, subset_size, statistics_dir)

    aggregator = StatisticsAggregatorFactory.create(model, dataset)
    compression_algorithm.set_backend_entity(model)
    _, matmul_input_to_output_nodes_map = compression_algorithm.get_compression_nodes_info(graph)
    register_statistics_for_algorithm(
        aggregator,
        model,
        graph,
        compression_algorithm,
        matmul_input_to_output_nodes_map,
    )
    aggregator.load_statistics_from_dir(statistics_dir)

    return compression_algorithm.apply(model, graph, aggregator.statistic_points, dataset)

nncf/quantization/algorithms/weight_compression/algorithm.py

+4
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,10 @@ def set_backend_entity(self, model: TModel) -> None:
323323
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXWeightCompressionAlgoBackend
324324

325325
self._backend_entity = FXWeightCompressionAlgoBackend()
326+
elif model_backend == BackendType.ONNX:
327+
from nncf.quantization.algorithms.weight_compression.onnx_backend import ONNXWeightCompressionAlgoBackend
328+
329+
self._backend_entity = ONNXWeightCompressionAlgoBackend()
326330
else:
327331
msg = f"Cannot return backend-specific entity because {model_backend.value} is not supported!"
328332
raise nncf.UnsupportedBackendError(msg)

0 commit comments

Comments
 (0)