Support transposed input for data-aware Weights Compression #3296

Open · wants to merge 24 commits into base: develop
Changes from all commits
2 changes: 1 addition & 1 deletion nncf/openvino/graph/nncf_graph_builder.py
@@ -99,7 +99,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
in_node_id = graph.get_node_by_name(op.get_friendly_name()).node_id
for output_port_id, out in enumerate(op.outputs()):
node_vs_target_inputs = defaultdict(list)
for inp in out.get_target_inputs():
for inp in sorted(out.get_target_inputs(), key=lambda inp: inp.get_node().get_friendly_name()):
node_vs_target_inputs[inp.get_node()].append(inp)

for out_node, inputs in node_vs_target_inputs.items():
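For context, a minimal plain-Python sketch (invented data, not the OpenVINO API) of what the sorting above buys: get_target_inputs() returns a set-like collection with no stable iteration order, so keying on the consumer's friendly name makes edge insertion deterministic across runs.

# Sets have no guaranteed iteration order; sorting by a stable key fixes it.
target_inputs = {("node_b", 0), ("node_a", 1), ("node_c", 0)}  # (friendly name, port)

ordered = sorted(target_inputs, key=lambda inp: inp[0])  # sort by friendly name
print(ordered)  # [('node_a', 1), ('node_b', 0), ('node_c', 0)]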
25 changes: 14 additions & 11 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -686,19 +686,22 @@ def apply(
)
return transformed_model

def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]:
def _get_activation_node_port_and_channel(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int, int]:
"""
This method returns the activation layer and corresponding port id for the node.
This method returns the activation layer, corresponding port id and channel axis for the given node.

:param node: NNCFGraph node for which the activation is sought.
:param nncf_graph: NNCFGraph instance with the node.
:return: Tuple with the activation node and port id.
:return: Tuple with the activation node, port id and channel axis.
"""
activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph)
activation_edge = nncf_graph.get_input_edge_by_port_id(node, activation_port)
activation_node = activation_edge.from_node
port_id = activation_edge.output_port_id
return activation_node, port_id
activation_channel_axis = self._backend_entity.get_activation_channel_axis(
node, port_id, activation_edge.tensor_shape
)
return activation_node, port_id, activation_channel_axis

def get_matmul_input_to_output_nodes_map(
self, matmul_nodes: List[NNCFNode], graph: NNCFGraph
@@ -719,8 +722,8 @@
"""
matmul_input_to_output_nodes_map = defaultdict(list)
for node in matmul_nodes:
act_node, output_port_id = self._get_activation_node_and_port(node, graph)
matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node)
act_node, output_port_id, act_channel_axis = self._get_activation_node_port_and_channel(node, graph)
matmul_input_to_output_nodes_map[(act_node, output_port_id, act_channel_axis)].append(node)
return matmul_input_to_output_nodes_map

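As an illustration of the new three-part grouping key (invented node names, not NNCF API): MatMuls that share the activation source, output port, and channel axis are grouped together, so their statistics are collected once per group.

from collections import defaultdict

# (activation node, output port, channel axis) -> consuming MatMul nodes
records = [
    ("embedding", 0, 2, "matmul_q"),
    ("embedding", 0, 2, "matmul_k"),
    ("layernorm", 0, 1, "matmul_v"),  # transposed input: channel axis is 1, not 2
]

groups = defaultdict(list)
for act, port, axis, mm in records:
    groups[(act, port, axis)].append(mm)

print(dict(groups))
# {('embedding', 0, 2): ['matmul_q', 'matmul_k'], ('layernorm', 0, 1): ['matmul_v']}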
def get_compression_nodes_info(
@@ -786,15 +789,15 @@ def get_statistic_points(
statistic_container = StatisticPointsContainer()
# Statistics for data aware algorithms
if self._data_aware_compression:
for node, output_port_id in nodes_and_port_ids:
for node, output_port_id, input_channel_axis in nodes_and_port_ids:
statistic_point = self._backend_entity.target_point(
TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id
)
# Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden
# size dimension.
# Reduce activations across all but the hidden dimension.
n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
reduction_axes = tuple(set(range(n_dims)) - {input_channel_axis % n_dims})
stat_collector = self._backend_entity.mean_statistic_collector(
reduction_axes=tuple(range(n_dims - 1)), subset_size=self._subset_size
reduction_axes=reduction_axes, subset_size=self._subset_size
)
statistic_container.add_statistic_point(
StatisticPoint(
@@ -831,7 +834,7 @@ def _get_statistics_for_weights_compression(
# Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions,
# shape is the original shape of the activation before reduction, and n is the size of the dataset (or subset_size).
statistics = {}
for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items():
for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items():
tensor_collectors = list(
statistic_points.get_algo_statistics_for_node(
act_node.node_name,
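To make the reduction-axes arithmetic in get_statistic_points concrete, a small numpy sketch with assumed [B, SeqLen, HiddenDim] shapes: reducing over every axis except the channel axis yields a 1D per-channel mean for both the regular and the transposed layout.

import numpy as np

def reduction_axes(n_dims, channel_axis):
    # Reduce over every axis except the (normalized) channel axis.
    return tuple(set(range(n_dims)) - {channel_axis % n_dims})

regular = np.ones((1, 24, 16))     # [B, SeqLen, HiddenDim], channel axis -1
transposed = np.ones((1, 16, 24))  # [B, HiddenDim, SeqLen], channel axis -2

print(reduction_axes(3, -1), np.mean(regular, axis=reduction_axes(3, -1)).shape)     # (0, 1) (16,)
print(reduction_axes(3, -2), np.mean(transposed, axis=reduction_axes(3, -2)).shape)  # (0, 2) (16,)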
12 changes: 12 additions & 0 deletions nncf/quantization/algorithms/weight_compression/backend.py
@@ -257,6 +257,18 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) ->
:return: Backend-specific callable to filter statistic containers according to its statistic point.
"""

@staticmethod
@abstractmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
"""
Returns the axis of the activation tensor that corresponds to its channel dimension.

:param node: NNCFNode instance.
:param port_id: Port ID for input.
:param input_shape: Shape of the input.
:return: Channel axis number.
"""


class AWQAlgoBackend(WeightCompressionAlgoBackend):
@staticmethod
31 changes: 21 additions & 10 deletions nncf/quantization/algorithms/weight_compression/gptq.py
@@ -123,8 +123,13 @@ def apply(
]:
continue
_, input_tensors = next(iter(inputs.items()))
hessian = self._calculate_hessian(node, input_tensors)
scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors)
input_channel_axis = self._backend_entity.get_activation_channel_axis(
node, self._backend_entity.get_activation_port_id(node, graph), input_tensors[0].shape
)
hessian = self._calculate_hessian(node, input_tensors, input_channel_axis)
scale, zero_point = self._quantize_weights(
model, graph, wc_params, hessian, input_tensors, input_channel_axis
)
scales[wc_params.weight_name] = scale
zero_points[wc_params.weight_name] = zero_point

@@ -157,7 +162,7 @@ def get_statistic_points(

return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes)

def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor:
def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor], input_channel_axis: int) -> Tensor:
"""
Calculates the Hessian matrix for the given node and inputs.

@@ -170,19 +175,18 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor:
if node.metatype in self._backend_entity.convolution_metatypes:
msg = "Convolution metatypes are not supported"
raise nncf.UnsupportedModelError(msg)
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)

hessian = fns.zeros(
(inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
(inputs[0].shape[input_channel_axis], inputs[0].shape[input_channel_axis]),
backend=inputs[0].backend,
dtype=TensorDataType.float32,
)

for inp in inputs:
batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
if node.metatype in self._backend_entity.matmul_metatypes:
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.reshape((-1, inp.shape[input_channel_axis]))
inp = fns.transpose(inp)
hessian *= nsamples / (nsamples + batch_size)
nsamples += batch_size
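The loop above maintains a running estimate of the GPTQ Hessian H ≈ (2/n) Σ x xᵀ over the channel dimension. A hedged numpy sketch of the same update (the accumulation line is truncated in the diff above, and shapes are invented; GPTQ's reference folds the 2/n factor into the samples before the outer product):

import numpy as np

def update_hessian(H, X, nsamples, batch_size):
    H = H * (nsamples / (nsamples + batch_size))  # decay the running estimate
    nsamples += batch_size
    X = X * np.sqrt(2.0 / nsamples)               # fold 2/n into the samples
    return H + X @ X.T, nsamples                  # accumulate the outer product

H = np.zeros((16, 16))      # channel_dim x channel_dim
X = np.random.rand(16, 24)  # channels x flattened (batch * seq) samples
H, n = update_hessian(H, X, nsamples=0, batch_size=1)
print(H.shape, n)           # (16, 16) 1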
@@ -198,6 +202,7 @@ def _quantize_weights(
wc_params: WeightCompressionParameters,
hessian: Tensor,
inputs: List[Tensor],
input_channel_axis: int,
):
"""
Quantizes the weights of the model based on the calculated Hessian matrix.
@@ -267,8 +272,14 @@
scales.append(scale)
else:
if self._scale_estimation and block_compression_config.num_bits == 4:
activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
activations = (
[inp[..., (i1 + i) : (i1 + i + group_size), :] for inp in inputs]
if input_channel_axis != (len(inputs[0].shape) - 1)
else [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
)
wc_statistics = ScaleEstimation.activations_to_wc_statistics(
activations, input_channel_axis
)
scale, zero_point = ScaleEstimation.calculate_quantization_params(
wc_statistics,
weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
Original file line number Diff line number Diff line change
@@ -269,7 +269,7 @@ def get_statistic_points(
self._set_backend_entity(model)

statistic_container = StatisticPointsContainer()
for act_node, output_port_id in nodes_and_port_ids:
for act_node, output_port_id, _ in nodes_and_port_ids:
n_dims = len(graph.get_output_edges_by_port_id(act_node, output_port_id)[0].tensor_shape)
if n_dims < 2:
msg = (
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
from nncf.openvino.graph.model_transformer import OVModelTransformer
from nncf.openvino.graph.node_utils import convert_op
from nncf.openvino.graph.node_utils import create_ov_const_from_tensor
from nncf.openvino.graph.node_utils import get_activation_channel_axis
from nncf.openvino.graph.node_utils import get_const_value_as_numpy_tensor
from nncf.openvino.graph.node_utils import get_const_value_as_ov_tensor
from nncf.openvino.graph.node_utils import get_weight_channel_axes
@@ -113,9 +114,6 @@ def mean_statistic_collector(

@staticmethod
def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
if node.layer_attributes.input_attributes["transpose"]:
msg = "Transposed input is not supported"
raise nncf.UnsupportedModelError(msg)
constant_ports = node.layer_attributes.get_const_port_ids()
activation_ports = [
e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports
@@ -198,7 +196,12 @@ def insert_adapters(
A_W = opset.constant(lora_A.data)
B_W = opset.constant(lora_B.data)

A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
A_MM = opset.matmul(
input_node,
A_W,
transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes["transpose"],
transpose_b=True,
)
B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)

node_output_port = mm_node.output(0)
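A shape-level numpy sketch (invented dimensions) of why the first LoRA matmul inherits the original MatMul's transpose_a flag: an activation stored as [HiddenDim, SeqLen] must be read transposed before projecting down to the LoRA rank.

import numpy as np

S, H, r = 24, 16, 4            # seq len, hidden dim, LoRA rank (invented)
x_t = np.random.rand(H, S)     # activation stored transposed, as with transpose_a=True
lora_A = np.random.rand(r, H)  # lora_A, multiplied with transpose_b=True

# opset.matmul(x_t, lora_A, transpose_a=True, transpose_b=True) computes x_t.T @ lora_A.T
out = x_t.T @ lora_A.T
print(out.shape)  # (24, 4) -> [S, r], same as the non-transposed case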
@@ -358,6 +361,10 @@ def filter_func(point: StatisticPoint) -> bool:

return filter_func

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
return get_activation_channel_axis(node, port_id, input_shape)


class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend):
"""
Original file line number Diff line number Diff line change
@@ -357,7 +357,7 @@ def calculate_quantization_params(
return result_scale, zp

@staticmethod
def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic:
def activations_to_wc_statistics(activations: List[Tensor], input_channel_axis: int) -> WCTensorStatistic:
"""
Mimic the activation reducing logic from WeightCompression.get_statistic_points.

@@ -368,7 +368,7 @@ def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic
shapes = []
for act in activations:
shapes.append(act.shape)
reduction_shape = tuple(range(act.ndim - 1))
reduction_shape = tuple(set(range(len(act.shape))) - {input_channel_axis % len(act.shape)})
mean_values.append(fns.mean(act, axis=reduction_shape))
wc_statistics = WCTensorStatistic(mean_values, shapes)
return wc_statistics
Original file line number Diff line number Diff line change
@@ -511,6 +511,10 @@ def transform_model(

return transformed_model

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
return node.metatype.output_channel_axis


class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):
@staticmethod
Original file line number Diff line number Diff line change
@@ -270,3 +270,7 @@ def transform_model(
transformed_model = FXModelTransformer(model).transform(transformation_layout)

return transformed_model

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
return node.metatype.output_channel_axis
7 changes: 5 additions & 2 deletions tests/openvino/native/quantization/test_gptq.py
@@ -346,7 +346,10 @@ def test_calculate_scale_linear():

nodes = graph.get_all_nodes()
wrapped_inputs = [Tensor(inp) for inp in inputs]
H = gptq._calculate_hessian(nodes[1], wrapped_inputs)
input_channel_axis = gptq._backend_entity.get_activation_channel_axis(
nodes[1], gptq._backend_entity.get_activation_port_id(nodes[1], graph), wrapped_inputs[0].shape
)
H = gptq._calculate_hessian(nodes[1], wrapped_inputs, input_channel_axis)

ref_H = ref_gptq.H.numpy()
assert np.all(np.isclose(ref_H, H.data))
@@ -356,7 +359,7 @@
)
wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16)

scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs)
scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs, input_channel_axis)
ref_scale = ref_scale.numpy()
scale = scale.reshape(ref_scale.shape)
assert np.all(np.isclose(ref_scale, scale.data))
36 changes: 19 additions & 17 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -11,7 +11,7 @@

import inspect
import os
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional

import numpy as np
import openvino.runtime as ov
@@ -89,7 +89,9 @@ class LMLinearModel(OVReferenceModel):
HIDDEN_DIM = 16
INPUT_SHAPE = [1, 24, HIDDEN_DIM] # [B, SeqLen, HiddenDim]

def _create_ov_model(self, transpose_b: bool = True, transpose_a=False, input_shape=None):
def _create_ov_model(
self, transpose_b: bool = True, transpose_a: bool = False, input_shape: Optional[List[int]] = None
):
self._input_shape = self.INPUT_SHAPE if input_shape is None else input_shape
hdim_axis = -2 if transpose_a else -1
self._hidden_dim = self._input_shape[hdim_axis]
@@ -1457,7 +1459,7 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs):

@pytest.mark.parametrize(
"kwargs",
[
(
dict(scale_estimation=True),
dict(lora_correction=True),
dict(
@@ -1466,25 +1468,25 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs):
scale_estimation=True,
advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)),
),
],
),
ids=["se", "lora", "gptq_se_awq"],
)
def test_compression_with_transposed_activations(kwargs):
def test_compression_with_transpose(kwargs):
dataset_size = 4
model = LMLinearModel(transpose_a=True, transpose_b=False).ov_model
model = LMLinearModel(transpose_a=True, transpose_b=True).ov_model
input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size
dataset = Dataset(input_data)

with pytest.raises(nncf.UnsupportedModelError):
ljaljushkin (Contributor) commented on Mar 4, 2025:
Since there's some issue with transpose_a=False, transpose_b=False, could you please make the test more general to cover all 4 combinations, as follows?

from contextlib import nullcontext


@pytest.mark.parametrize(
    ("transpose_a", "transpose_b", "raises_error"),
    (
        (False, True, False),
        (True, True, False),
        (False, False, True),
        (True, False, True),
    ),
    ids=["tb_nota", "ta_tb", "nota_notb", "ta_notb"]
)
@pytest.mark.parametrize(
    "kwargs",
    (
        dict(scale_estimation=True),
        dict(lora_correction=True),
        dict(
            gptq=True,
            awq=True,
            scale_estimation=True,
            advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)),
        ),
    ),
    ids=['se', 'lora', 'gptq_se_awq']
)
def test_compression_with_transpose(transpose_a, transpose_b, raises_error, kwargs):
    dataset_size = 4
    model = LMLinearModel(transpose_a=transpose_a, transpose_b=transpose_b).ov_model
    input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size
    dataset = Dataset(input_data)

    with pytest.raises(nncf.UnsupportedModelError) if raises_error else nullcontext():
        compress_weights(
            model,
            mode=CompressWeightsMode.INT4_SYM,
            ratio=1.0,
            group_size=8,
            subset_size=2,
            dataset=dataset,
            all_layers=True,
            **kwargs,
        )

The PR author (Contributor) replied:
For kwargs such as dict(lora_correction=True) and dict(awq=True), it seems that the following cases:

@pytest.mark.parametrize(
    ("transpose_a", "transpose_b", "raises_error"),
    (
        (False, False, True),
        (True, False, True),
    ),
)

do not raise an error and seem to pass for those two algorithms. Would you suggest I still implement a check to raise nncf.UnsupportedModelError for these algorithms when transpose_b=False, or not?

ljaljushkin (Contributor) replied on Mar 4, 2025:
The issue with transpose_b=False wasn't mentioned in the GFI, so I can't insist on implementing it within this PR.
However, it would be wonderful if you could resolve both issues! :) In that case, raises_error won't be needed, and the test wouldn't expect an exception for any combination.

ljaljushkin (Contributor) added on Mar 4, 2025:
No need for a template test: #3230 (comment). And dict(awq=True) is not needed either.

The PR author (Contributor) replied:
> The issue with transpose_b=False wasn't mentioned in the GFI, so I can't insist on implementing it within this PR. However, it would be wonderful if you could resolve both issues! :) In that case, raises_error won't be needed, and the test wouldn't expect an exception for any combination.

So, should I try implementing transpose_b=False in this PR or that would be out of scope? I wouldn't mind either :)

ljaljushkin (Contributor) replied:
Both options are possible. The choice is yours.

But regardless of your choice, please extend the test to verify all combinations as originally suggested.
If you decide not to implement transpose_b=False, that's fine; the test will keep expecting an error so the known issue isn't forgotten.
If you fix the issue for transpose_b=False, then all combinations should pass without expecting an error.

compress_weights(
model,
mode=CompressWeightsMode.INT4_SYM,
ratio=1.0,
group_size=8,
subset_size=2,
dataset=dataset,
all_layers=True,
**kwargs,
)


class TestOVTemplateWeightCompression(TemplateWeightCompression):
Expand Down