[ONNX]: Add support for data-free Weight Compression Algorithm (#3273) #3346

Open
wants to merge 1 commit into base: develop
57 changes: 57 additions & 0 deletions nncf/onnx/graph/layer_attributes.py
@@ -0,0 +1,57 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

from nncf.common.graph.layer_attributes import BaseLayerAttributes


class ONNXLayerAttributes(BaseLayerAttributes):
"""
This class stores additional information about a node that needs to be processed during compression.
"""

def __init__(
self,
constant_attributes: Dict[int, Any],
layer_attributes: Optional[BaseLayerAttributes] = None,
inputs_attributes: Optional[Dict[Any, Any]] = None,
):
"""
:param constant_attributes: Map of weight port IDs to the corresponding constant attributes.
:param layer_attributes: Common layer attributes of the node, if any.
:param inputs_attributes: Activation (input) attributes.
"""
self._constant_attributes = constant_attributes
self._layer_attributes = layer_attributes
self._inputs_attributes = inputs_attributes

@property
def constant_attributes(self) -> Dict[int, Any]:
return self._constant_attributes

@property
def layer_attributes(self) -> Optional[BaseLayerAttributes]:
return self._layer_attributes

@property
def input_attributes(self) -> Optional[Dict[Any, Any]]:
return self._inputs_attributes

def get_const_port_ids(self) -> List[int]:
"""
Returns indices of input ports corresponding to the constant nodes.

:returns: List of input port indices with constants.
"""
if self._constant_attributes is not None:
return list(self._constant_attributes.keys())
return []
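For reference, a minimal sketch of how these attributes get consumed (illustrative only, not part of the diff; the "shape"/"transpose" keys are assumed from how layout.py reads constant_attributes):

from nncf.onnx.graph.layer_attributes import ONNXLayerAttributes

# The weight constant sits on input port 1 of a MatMul-like node.
attrs = ONNXLayerAttributes(constant_attributes={1: {"shape": (768, 3072), "transpose": False}})
assert attrs.get_const_port_ids() == [1]
assert attrs.constant_attributes[1]["shape"] == (768, 3072)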
137 changes: 137 additions & 0 deletions nncf/onnx/graph/layout.py
@@ -0,0 +1,137 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Tuple

from nncf.common.graph.graph import NNCFNode
from nncf.onnx.graph.layer_attributes import ONNXLayerAttributes
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDepthwiseConvolutionMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGroupConvolutionMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype


class ONNXLayoutElem(Enum):
"""
Layout element descriptor for convolutional and linear ONNX layers:
C_IN: Input channels dimension.
C_OUT: Output channels dimension.
SPATIAL: Spatial dimension.
GROUPS: Groups dimension.
"""

C_IN = "channels_in"
C_OUT = "channels_out"
SPATIAL = "spatial"
GROUPS = "groups"


_CONV_BASE_CONST_LAYOUT = {
ONNXConvolutionMetatype: (ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN),
ONNXDepthwiseConvolutionMetatype: (ONNXLayoutElem.GROUPS, ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN),
ONNXGroupConvolutionMetatype: (ONNXLayoutElem.GROUPS, ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN),
}


def get_conv_weights_layout_from_node(node: NNCFNode) -> Tuple[ONNXLayoutElem]:
"""
Calculates weights layout for a target convolution node.

:param node: Target convolution node.
:return: Target convolution Node weights layout.
"""
layer_attributes = node.layer_attributes
port_id = _get_constant_port_id_from_layer_attributes(layer_attributes)
return get_conv_weights_layout(
ONNX_metatype=node.metatype, weights_shape=layer_attributes.constant_attributes[port_id]["shape"]
)


def get_linear_weights_layout_from_node(node: NNCFNode) -> Tuple[ONNXLayoutElem]:
"""
Calculates weights layout for a target linear node.

:param node: Target linear node.
:return: Target linear node weights layout.
"""
layer_attributes = node.layer_attributes
port_id = _get_constant_port_id_from_layer_attributes(layer_attributes)
constant_layer_attrs = layer_attributes.constant_attributes[port_id]
return get_linear_input_layout(
input_shape=constant_layer_attrs["shape"],
transpose=constant_layer_attrs["transpose"],
port_id=port_id,
)


def get_linear_activations_layout_from_node(
node: NNCFNode, port_id: int, input_shape: Tuple[int]
) -> Tuple[ONNXLayoutElem]:
"""
Calculates activations layout for a target linear node.

:param node: Target linear node.
:param port_id: Target input port ID.
:param input_shape: Shape of the input.
:return: Target linear node activations layout.
"""
act_layer_attrs = node.layer_attributes.input_attributes
return get_linear_input_layout(
input_shape=input_shape,
transpose=act_layer_attrs["transpose"],
port_id=port_id,
)


def get_conv_weights_layout(ONNX_metatype: ONNXOpMetatype, weights_shape: Tuple[int, ...]) -> Tuple[ONNXLayoutElem]:
"""
Calculates weights layout for a target convolution node.

:param ONNX_metatype: Target convolution node ONNX metatype.
:param weights_shape: Shape of the target convolution node weight.
:return: Target convolution node weights layout.
"""
base_layout = _CONV_BASE_CONST_LAYOUT[ONNX_metatype]
kernel_size = weights_shape[len(base_layout) :]
weights_layout = list(base_layout) + [ONNXLayoutElem.SPATIAL] * len(kernel_size)
return tuple(weights_layout)


def get_linear_input_layout(input_shape: Tuple[int, ...], transpose: bool, port_id: int) -> Tuple[ONNXLayoutElem]:
"""
Calculates input layout for a target linear node.

:param input_shape: Shape of the target linear node input.
:param transpose: Whether the target linear node input is transposed.
:param port_id: Port id of the target linear node input.
:return: Target linear node input layout.
"""
input_layout = [ONNXLayoutElem.SPATIAL] * (len(input_shape) - 2)
if len(input_shape) > 1:
if (transpose and port_id == 0) or (not transpose and port_id == 1):
input_layout += [ONNXLayoutElem.C_IN, ONNXLayoutElem.C_OUT]
else:
input_layout += [ONNXLayoutElem.C_OUT, ONNXLayoutElem.C_IN]
else:
input_layout += [ONNXLayoutElem.C_IN]
return tuple(input_layout)


def _get_constant_port_id_from_layer_attributes(layer_attributes: ONNXLayerAttributes) -> int:
"""
Returns the constant port ID for convolutional and linear ops layer attributes.

:param layer_attributes: Target convolutional/linear layer op layer attributes.
:return: Constant port ID for the target convolutional/linear node.
"""
port_ids = list(layer_attributes.constant_attributes.keys())
assert len(port_ids) == 1
return port_ids[0]
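To illustrate the helpers above, a short sketch of the layouts they produce (illustrative only, assuming the module added in this PR):

from nncf.onnx.graph.layout import get_conv_weights_layout
from nncf.onnx.graph.layout import get_linear_input_layout
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype

# Regular 2D convolution weight of shape (C_out, C_in, K_h, K_w):
conv_layout = get_conv_weights_layout(ONNX_metatype=ONNXConvolutionMetatype, weights_shape=(64, 3, 3, 3))
# -> (C_OUT, C_IN, SPATIAL, SPATIAL)

# MatMul weight on port 1, not transposed: the last two dimensions are (C_IN, C_OUT).
matmul_layout = get_linear_input_layout(input_shape=(768, 3072), transpose=False, port_id=1)
# -> (C_IN, C_OUT)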
6 changes: 6 additions & 0 deletions nncf/onnx/graph/metatypes/groups.py
@@ -164,3 +164,9 @@
onnx_metatypes.ONNXROIAlignMetatype,
onnx_metatypes.ONNXEmbeddingMetatype,
]

CONV_OPERATIONS = [
onnx_metatypes.ONNXConvolutionMetatype,
onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
onnx_metatypes.ONNXGroupConvolutionMetatype,
]
46 changes: 45 additions & 1 deletion nncf/onnx/graph/node_utils.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnx
@@ -18,10 +18,20 @@
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.logging.logger import nncf_logger
from nncf.onnx.graph.layout import ONNXLayoutElem
from nncf.onnx.graph.layout import get_conv_weights_layout
from nncf.onnx.graph.layout import get_conv_weights_layout_from_node
from nncf.onnx.graph.layout import get_linear_activations_layout_from_node
from nncf.onnx.graph.layout import get_linear_input_layout
from nncf.onnx.graph.layout import get_linear_weights_layout_from_node
from nncf.onnx.graph.metatypes import onnx_metatypes as om
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDequantizeLinearMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXMatMulMetatype
from nncf.onnx.graph.onnx_helper import get_tensor_value
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
from nncf.onnx.graph.metatypes.groups import CONV_OPERATIONS
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS


def is_node_with_bias(node: NNCFNode) -> bool:
@@ -139,6 +149,36 @@ def get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
weight_channel_axis = -1 - port_id if transpose else -2 + port_id
return weight_channel_axis

def get_weight_channel_axes(node: NNCFNode) -> List[int]:
"""
Returns axes numbers of the weight tensor which correspond to its channels.

:param node: NNCFNode with weights.
:return: Axes numbers of the weight tensor which correspond to its channels.
"""
if node.metatype not in OPERATIONS_WITH_WEIGHTS:
msg = "Channel axis cannot be defined for operation without weights."
raise ValueError(msg)

if node.metatype in CONV_OPERATIONS:
weights_layout = get_conv_weights_layout_from_node(node)
return [idx for idx, elem in enumerate(weights_layout) if elem in [ONNXLayoutElem.GROUPS, ONNXLayoutElem.C_OUT]]
elif node.metatype == ONNXMatMulMetatype:
return get_matmul_channel_axes(node)
return node.metatype.const_channel_axis


def get_matmul_channel_axes(node: NNCFNode) -> List[int]:
"""
Calculate channel axes for the MatMul operation.

:param node: The target node.
:return: List of channel axes for the MatMul operation.
"""
weights_layout = get_linear_weights_layout_from_node(node)
return [idx for idx, elem in enumerate(weights_layout) if elem in [ONNXLayoutElem.SPATIAL, ONNXLayoutElem.C_OUT]]


def get_act_quantization_axis(node: NNCFNode, port_id: int) -> int:
"""
@@ -214,3 +254,7 @@ def get_quantized_tensor_shape(
if target_point.is_weight_target_point():
return node.layer_attributes.weight_attrs[target_point.port_id]["shape"]
return _get_activation_tensor_shape(nncf_graph, node, target_point)


def get_const_value_as_onnx_tensor(initializer_name: str, model: onnx.ModelProto) -> np.ndarray:
# TODO
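get_const_value_as_onnx_tensor is left as a TODO in this commit (note that the name suggests an ONNX tensor while the annotation promises np.ndarray). A minimal sketch of one possible implementation using the standard onnx.numpy_helper API, as an assumption rather than the author's code:

import numpy as np
import onnx
from onnx import numpy_helper


def get_const_value_as_onnx_tensor(initializer_name: str, model: onnx.ModelProto) -> np.ndarray:
    # Look up the initializer by name and convert it to a NumPy array.
    for initializer in model.graph.initializer:
        if initializer.name == initializer_name:
            return numpy_helper.to_array(initializer)
    msg = f"Initializer '{initializer_name}' was not found in the model."
    raise RuntimeError(msg)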
71 changes: 71 additions & 0 deletions nncf/onnx/quantization/quantize_model.py
@@ -9,29 +9,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union

import onnx

import nncf
from nncf.common.factory import NNCFGraphFactory
from nncf.common.factory import StatisticsAggregatorFactory
from nncf.common.logging.logger import nncf_logger
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.onnx.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.parameters import BackupMode
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.parameters import DropType
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import SensitivityMetric
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.algorithms.accuracy_control.algorithm import QuantizationAccuracyRestorer
from nncf.quantization.algorithms.accuracy_control.algorithm import calculate_accuracy_drop
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
from nncf.quantization.quantize_model import quantize_with_tune_hyperparams
from nncf.quantization.quantize_model import warning_model_no_batchwise_support
from nncf.quantization.statistics_caching import cache_weight_compression_statistics
from nncf.quantization.statistics_caching import register_statistics_for_algorithm
from nncf.scopes import IgnoredScope

TTensor = TypeVar("TTensor")
@@ -201,3 +212,63 @@ def quantize_with_accuracy_control_impl(
)

return quantized_model

def compress_weights_impl(
model: onnx.ModelProto,
dataset: Dataset,
mode: CompressWeightsMode,
ratio: float,
group_size: int,
ignored_scope: IgnoredScope,
all_layers: bool,
sensitivity_metric: SensitivityMetric,
awq: bool,
subset_size: int,
scale_estimation: bool,
gptq: bool,
lora_correction: bool,
backup_mode: BackupMode,
compression_format: CompressionFormat,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> onnx.ModelProto:
"""
Implementation of the `compress_weights()` method for the ONNX backend.
"""
graph = NNCFGraphFactory.create(model)
compression_algorithm = WeightCompression(
mode,
ratio,
group_size,
ignored_scope,
all_layers,
sensitivity_metric,
awq,
subset_size,
scale_estimation,
gptq,
lora_correction,
backup_mode,
compression_format,
advanced_parameters,
)

statistics_points = None
if advanced_parameters and advanced_parameters.statistics_path:
# If the statistics directory does not exist yet, collect and cache the statistics first
statistics_path = Path(advanced_parameters.statistics_path)
if not statistics_path.exists():
cache_weight_compression_statistics(model, graph, dataset, subset_size, statistics_path)
statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
compression_algorithm.set_backend_entity(model)
_, matmul_input_to_output_nodes_map = compression_algorithm.get_compression_nodes_info(graph)
register_statistics_for_algorithm(
statistics_aggregator,
model,
graph,
compression_algorithm,
matmul_input_to_output_nodes_map,
)
statistics_aggregator.load_statistics_from_dir(statistics_path)
statistics_points = statistics_aggregator.statistic_points

return compression_algorithm.apply(model, graph, statistics_points, dataset)
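For context, compress_weights_impl above is normally reached through the public nncf.compress_weights() entry point. A minimal usage sketch, assuming this PR is merged and that the data-free INT4 modes are enabled for the ONNX backend (the option names come from the existing nncf API; the model path is hypothetical):

import onnx

import nncf

model = onnx.load("model.onnx")
# Data-free weight compression: no calibration dataset is passed.
compressed = nncf.compress_weights(model, mode=nncf.CompressWeightsMode.INT4_SYM, ratio=1.0, group_size=128)
onnx.save(compressed, "model_int4.onnx")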
@@ -323,6 +323,10 @@ def set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXWeightCompressionAlgoBackend

self._backend_entity = FXWeightCompressionAlgoBackend()
elif model_backend == BackendType.ONNX:
from nncf.quantization.algorithms.weight_compression.onnx_backend import ONNXWeightCompressionAlgoBackend

self._backend_entity = ONNXWeightCompressionAlgoBackend()
else:
msg = f"Cannot return backend-specific entity because {model_backend.value} is not supported!"
raise nncf.UnsupportedBackendError(msg)