Support transposed input for data-aware Weights Compression #3296

Open · wants to merge 23 commits into develop · Changes from 12 commits
10 changes: 8 additions & 2 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -774,9 +774,15 @@ def get_statistic_points(
             )
             # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden
             # size dimension.
-            n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
+            output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0]
+            n_dims = len(output_edge.tensor_shape)
+            reduction_axes = tuple(
+                i
+                for i in range(n_dims)
+                if i != n_dims + self._backend_entity.get_input_hidden_dim(output_edge.to_node)
Contributor:

What if get_input_hidden_dim is positive?

Contributor Author (@rk119, Mar 4, 2025):

> What if get_input_hidden_dim is positive?

I created this backend function mainly because, in comparison to OV, Torch tensors are reduced across all dimensions except the last one (transposed or not), hence -1 is returned by default. The method is overridden only for OV: when transpose_a=True it returns -2, keeping the second-to-last dimension from being reduced, since the opset.matmul operation will later swap it with the last dimension for transpose_a=True. When transpose_a=False it returns -1 as well.
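For illustration, a minimal NumPy sketch of this reasoning (editor's sketch, not NNCF code; the shapes and returned values are assumed from the description above):

import numpy as np

batch, seq, hidden, out_features = 2, 5, 8, 3
w = np.random.rand(hidden, out_features)

a = np.random.rand(batch, seq, hidden)   # transpose_a=False: hidden dim at -1
a_t = a.swapaxes(-1, -2)                 # transpose_a=True: hidden dim at -2

# opset.matmul with transpose_a=True swaps the last two dims back before
# multiplying, so both layouts produce the same result:
assert np.allclose(a @ w, a_t.swapaxes(-1, -2) @ w)

# Statistics are therefore reduced over every axis except the hidden one:
hidden_dim = -2                          # get_input_hidden_dim for OV + transpose_a=True
n_dims = a_t.ndim
reduction_axes = tuple(i for i in range(n_dims) if i != n_dims + hidden_dim)
assert reduction_axes == (0, 2)          # keep axis 1, the hidden dimension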

> I think you can use node.metatype.output_channel_axis directly without defining a backend method.

Unfortunately, node.metatype.output_channel_axis returns None here, so I'm not sure how to utilize it:

[screenshot]

Contributor:

I assume it should work for output_edge.to_node.

Suggest keeping get_input_hidden_dim, but probably rename it to get_activation_channel_axis and implement it for the OpenVINO backend a bit more generally, to cover cases with Convolutions.

You can call this get_activation_channel_axis from here for OpenVINO: https://github.com/openvinotoolkit/nncf/blob/develop/nncf/openvino/graph/node_utils.py#L631

node = output_edge.to_node
input_port_id = node.input_port_id
input_shape = node.tensor_shape
input_channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port_id, input_shape)

For torch, this way should work:

output_edge.to_node.metatype.output_channel_axis
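A hedged sketch of how these two suggestions could be combined (the dispatch and the edge attributes used below are assumptions for illustration, not committed NNCF API):

def resolve_activation_channel_axis(self, output_edge) -> int:
    node = output_edge.to_node
    # Torch path: the metatype carries the axis directly.
    if node.metatype.output_channel_axis is not None:
        return node.metatype.output_channel_axis
    # OpenVINO path: delegate to the node_utils helper referenced above,
    # which inspects the input port id and input shape.
    return self._backend_entity.get_activation_channel_axis(
        node, output_edge.input_port_id, output_edge.tensor_shape
    )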

Contributor Author (@rk119, Mar 5, 2025):

Thank you! I was trying to find an existing implementation to obtain the dim but couldn't. My bad. Yep, this approach seems better.

+            )
             stat_collector = self._backend_entity.mean_statistic_collector(
-                reduction_axes=tuple(range(n_dims - 1)), subset_size=self._subset_size
+                reduction_axes=reduction_axes, subset_size=self._subset_size
             )
             statistic_container.add_statistic_point(
                 StatisticPoint(
10 changes: 10 additions & 0 deletions nncf/quantization/algorithms/weight_compression/backend.py
@@ -246,6 +246,16 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) ->
         :return: Backend-specific callable to filter statistic containers according to its statistic point.
         """

+    @staticmethod
+    def get_input_hidden_dim(input_node: NNCFNode) -> int:
+        """
+        Returns the index of the hidden dimension in the shape of the input node.
+
+        :param input_node: The input node.
+        :return: The index of the hidden dimension in the shape of the input node.
+        """
+        return -1


class AWQAlgoBackend(WeightCompressionAlgoBackend):
@staticmethod
19 changes: 12 additions & 7 deletions nncf/quantization/algorithms/weight_compression/gptq.py
@@ -170,19 +170,19 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor:
         if node.metatype in self._backend_entity.convolution_metatypes:
             msg = "Convolution metatypes are not supported"
             raise nncf.UnsupportedModelError(msg)
-        if node.layer_attributes.input_attributes["transpose"]:
-            msg = "Transposed input is not supported"
-            raise nncf.UnsupportedModelError(msg)

+        hidden_dim = self._backend_entity.get_input_hidden_dim(node)
         hessian = fns.zeros(
-            (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
+            (inputs[0].shape[hidden_dim], inputs[0].shape[hidden_dim]),
+            backend=inputs[0].backend,
+            dtype=TensorDataType.float32,
         )

         for inp in inputs:
             batch_size = 1 if len(inp.shape) == 2 else inp.shape[0]
             if node.metatype in self._backend_entity.matmul_metatypes:
                 if len(inp.shape) == 3:
-                    inp = inp.reshape((-1, inp.shape[-1]))
+                    inp = inp.reshape((-1, inp.shape[hidden_dim]))
                 inp = fns.transpose(inp)
             hessian *= nsamples / (nsamples + batch_size)
             nsamples += batch_size
@@ -267,8 +267,13 @@ def _quantize_weights(
                     scales.append(scale)
                 else:
                     if self._scale_estimation and block_compression_config.num_bits == 4:
-                        activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
-                        wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
+                        transpose = self._backend_entity.get_input_hidden_dim(wc_params.node_with_weight) == -2
+                        activations = (
+                            [inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs]
+                            if transpose
+                            else [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
+                        )
+                        wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations, transpose)
                         scale, zero_point = ScaleEstimation.calculate_quantization_params(
                             wc_statistics,
                             weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
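As an aside, an illustrative NumPy sketch of the shape handling in this hunk (editor's sketch, not NNCF code; it uses a moveaxis so it stays correct for either axis, whereas the diff reshapes the transposed input directly):

import numpy as np

hidden_dim = -2                      # transposed OV matmul input; -1 otherwise
inp = np.random.rand(2, 8, 5)        # [batch, hidden, seq] in the transposed case
h = inp.shape[hidden_dim]

hessian = np.zeros((h, h), dtype=np.float32)
# Flatten all samples along the hidden axis, then accumulate X.T @ X,
# the hidden x hidden Hessian that GPTQ builds from input activations.
x = np.moveaxis(inp, hidden_dim, -1).reshape(-1, h)
hessian += 2.0 * (x.T @ x)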
@@ -110,9 +110,6 @@ def mean_statistic_collector(

     @staticmethod
     def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
-        if node.layer_attributes.input_attributes["transpose"]:
-            msg = "Transposed input is not supported"
-            raise nncf.UnsupportedModelError(msg)
         constant_ports = node.layer_attributes.get_const_port_ids()
         activation_ports = [
             e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports
@@ -204,7 +201,12 @@ def insert_adapters(
         A_W = opset.constant(lora_A.data)
         B_W = opset.constant(lora_B.data)

-        A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True)
+        A_MM = opset.matmul(
+            input_node,
+            A_W,
+            transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes["transpose"],
+            transpose_b=True,
+        )
         B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True)

         node_output_port = mm_node.output(0)
@@ -364,6 +366,12 @@ def filter_func(point: StatisticPoint) -> bool:

         return filter_func

+    @staticmethod
+    def get_input_hidden_dim(node: NNCFNode) -> int:
+        if (node is not None) and (node.layer_attributes is not None):
Contributor:

I wonder what the valid case is when this condition is False? If it's an invalid case, I'd not silently return -1.

Contributor Author (@rk119, Mar 4, 2025):

Initially, I did not have that conditional check and only returned return -2 if node.layer_attributes.input_attributes["transpose"] else -1, which caused this test to fail as follows:

[screenshot]

While collecting statistics, the node passed to get_input_hidden_dim was None, so the check node.layer_attributes.input_attributes["transpose"] could not be performed and an error was raised.

I agree that returning -1 silently isn't ideal, but the current implementation always assumes by default that the last dim is the hidden one and excludes it using n_dims - 1.
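For reference, a consolidated sketch of the fallback under discussion (it mirrors the diff below; the comment spells out the assumption instead of leaving it silent):

def get_input_hidden_dim(node):
    if node is None or node.layer_attributes is None:
        # Nodes seen during statistics collection may lack layer attributes
        # (the failing test hit exactly this case); fall back to the
        # historical assumption that the last dimension is the hidden one.
        return -1
    return -2 if node.layer_attributes.input_attributes["transpose"] else -1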

Contributor:

That's interesting. Why does this test fail? Could it be because the node corresponds to an embedding instead of a matmul?
Probably, another approach will avoid this situation: #3296 (comment)

Contributor Author (@rk119, Mar 5, 2025):

> Probably, another approach will avoid this situation: #3296 (comment)

I implemented the approach specified in #3296 (comment), yet the test is still failing in the checks.

> That's interesting. Why does this test fail? Could it be because the node corresponds to an embedding instead of a matmul?

I am not sure, since I am unable to debug this test properly on my local machine due to a Windows fatal exception: access violation error.

Contributor Author (@rk119, Mar 10, 2025):

Hi @ljaljushkin,

I was able to run the test on my local machine by downgrading the onnx version, which was causing some errors. I noticed an inconsistency in the test failures.

Test run 1:

[screenshot]

_test_basic_configurations passes, but _test_advanced_gptq_scale_estimation fails for the configuration Testing: AWQ=True, GPTQ=False, Scale=True, LoRA=False.

Consecutive test run 2:

[screenshot]

_test_basic_configurations now fails for the configuration Testing: AWQ=True, Group=1, Ratio=0.4, Metric=hessian_input_activation, Scope=IgnoredScope(names=[], patterns=[], types=[], subgraphs=[], validate=True).

Both test runs fail with the same error:

[screenshot]

Furthermore, I debugged the test failures and confirmed that the configurations work as intended. The logic itself seems to have no flaws, and the previously failing cases execute properly in the debugger. It appears to be an issue with how statistics are cached internally (I am not sure how to word it precisely). I could be wrong about this, but the inconsistency seems to support that theory so far. However, I am also not sure why this inconsistency is triggered now and not before my changes. I would be grateful for some guidance and clarity on this :)

Contributor Author (@rk119, Mar 17, 2025):

@ljaljushkin Alright then. Before I implement and push those changes, I thought I would explain my findings before adding the sort.

The NNCF graphs differ between the tests that pass and the one that fails.

The following screenshot shows one MVN node whose edges point to child nodes in the following order for the tests that pass:

[screenshot]

In the NNCF graph of the failing model, this line of code accesses the node of the first edge, which later causes the error 'NoneType' object has no attribute 'get_const_port_ids':

[screenshot]

The screenshots above show the following node, with the last edge pointing to shape:

[screenshot]
[screenshot]

I also noticed in the raw NNCF graph code that the nodes added are the same; differences arise only when edges are added, confirming that the order in which the edges were added was probably causing the issue.

[screenshot]

I debugged further to see why the nodes are ordered the same but the edge lists differ.

I noticed that the child nodes returned for the same MVN node in two successful test runs come back in a non-deterministic order from out.get_target_inputs(), but they are sorted here before being added to the graph:

[screenshot]

[screenshot]

The same child nodes returned while adding the edges are not sorted, hence the present change (a sketch follows below).
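A minimal sketch of the described fix (OpenVINO Python API; the exact sort key is an assumption based on the node-sorting condition mentioned above):

# Sort the consumers of each output before creating edges so that graph
# construction is deterministic, mirroring the sort applied to nodes.
target_inputs = sorted(
    out.get_target_inputs(), key=lambda inp: inp.get_node().get_friendly_name()
)
for target_input in target_inputs:
    ...  # add the edge to target_input.get_node() as before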

Contributor:

Good to have the order of edges deterministic!

But I doubt it will help resolve the original problem with the channel axis. IMO, it only solves the problem for this particular instance with MVN; it doesn't necessarily guarantee that the first edge will always lead to a matmul with non-empty layer_attributes.input_attributes, does it?

Contributor Author (@rk119, Mar 18, 2025):

> it doesn't necessarily guarantee that the first edge will always lead to a matmul with non-empty layer_attributes.input_attributes, does it?

Ah yes! You are right that it won't always lead to a matmul :) I'll make the relevant changes and push them shortly.

> Good to have the order of edges deterministic!

So, would you recommend I keep these changes or revert them?

Contributor:

Yes, you could add this to avoid confusion between different graphs for the same model. Hopefully it doesn't significantly affect graph building for large models.

Contributor Author:

Alright. I just made a minor change to keep it consistent with the sort condition used when adding the nodes.

+            return -2 if node.layer_attributes.input_attributes["transpose"] else -1
+        return -1


 class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend):
     @staticmethod
@@ -365,7 +365,7 @@ def calculate_quantization_params(
         return result_scale, zp

     @staticmethod
-    def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic:
+    def activations_to_wc_statistics(activations: List[Tensor], transpose: bool) -> WCTensorStatistic:
Contributor:

I'd pass output_channel_axis explicitly and re-implement the definition of reduction_shape using this info.
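A sketch of that alternative (the signature is an assumption; fns and WCTensorStatistic are the names already used in this module):

@staticmethod
def activations_to_wc_statistics(activations, output_channel_axis=-1):
    mean_values, shapes = [], []
    for act in activations:
        shapes.append(act.shape)
        keep = act.ndim + output_channel_axis  # normalize the negative axis
        reduction_shape = tuple(i for i in range(act.ndim) if i != keep)
        mean_values.append(fns.mean(act, axis=reduction_shape))
    return WCTensorStatistic(mean_values, shapes)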

"""
Mimic the activation reducing logic from WeightCompression.get_statistic_points.

Expand All @@ -376,7 +376,9 @@ def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic
         shapes = []
         for act in activations:
             shapes.append(act.shape)
-            reduction_shape = tuple(range(act.ndim - 1))
+            reduction_shape = (
+                tuple(i for i in range(act.ndim) if i != act.ndim - 2) if transpose else tuple(range(act.ndim - 1))
+            )
             mean_values.append(fns.mean(act, axis=reduction_shape))
         wc_statistics = WCTensorStatistic(mean_values, shapes)
         return wc_statistics
23 changes: 11 additions & 12 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -1440,21 +1440,20 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs):
 )
 def test_compression_with_transposed_activations(kwargs):
     dataset_size = 4
-    model = LMLinearModel(transpose_a=True, transpose_b=False).ov_model
+    model = LMLinearModel(transpose_a=True, transpose_b=True).ov_model
     input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size
     dataset = Dataset(input_data)

-    with pytest.raises(nncf.UnsupportedModelError):
Contributor (@ljaljushkin, Mar 4, 2025):

Since there's some issue with transpose_a=False, transpose_b=False, could you please make the test more general to cover all 4 combinations, as follows?

from contextlib import nullcontext


@pytest.mark.parametrize(
    ("transpose_a", "transpose_b", "raises_error"),
    (
        (False, True, False),
        (True, True, False),
        (False, False, True),
        (True, False, True),
    ),
    ids=["tb_nota", "ta_tb", "nota_notb", "ta_notb"]
)
@pytest.mark.parametrize(
    "kwargs",
    (
        dict(scale_estimation=True),
        dict(lora_correction=True),
        dict(
            gptq=True,
            awq=True,
            scale_estimation=True,
            advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)),
        ),
    ),
    ids=['se', 'lora', 'gptq_se_awq']
)
def test_compression_with_transpose(transpose_a, transpose_b, raises_error, kwargs):
    dataset_size = 4
    model = LMLinearModel(transpose_a=transpose_a, transpose_b=transpose_b).ov_model
    input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size
    dataset = Dataset(input_data)

    with pytest.raises(nncf.UnsupportedModelError) if raises_error else nullcontext():
        compress_weights(
            model,
            mode=CompressWeightsMode.INT4_SYM,
            ratio=1.0,
            group_size=8,
            subset_size=2,
            dataset=dataset,
            all_layers=True,
            **kwargs,
        )

Contributor Author:

For kwargs such as dict(lora_correction=True) and dict(awq=True), it seems that the following cases:

@pytest.mark.parametrize(
    ("transpose_a", "transpose_b", "raises_error"),
    (
        (False, False, True),
        (True, False, True),
    ),
)

do not raise an error and seem to pass for those two algorithms. Would you suggest I still implement a check to raise an unsupported-model error for these algorithms when transpose_b=False?

Contributor (@ljaljushkin, Mar 4, 2025):

The issue with transpose_b=False wasn't mentioned in the GFI, so I can't insist on implementing it within this PR.
However, it would be wonderful if you could resolve both issues! :) In that case, raises_error won't be needed, and the test wouldn't expect an exception for any combination.

Contributor (@ljaljushkin, Mar 4, 2025):

No need for a template test: #3230 (comment).
And dict(awq=True) is not needed either.

Contributor Author:

> The issue with transpose_b=False wasn't mentioned in the GFI, so I can't insist on implementing it within this PR. However, it would be wonderful if you could resolve both issues! :) In that case, raises_error won't be needed, and the test wouldn't expect an exception for any combination.

So, should I try implementing transpose_b=False in this PR, or would that be out of scope? I wouldn't mind either :)

Contributor:

Both options are possible. The choice is yours.

But regardless of your choice, please extend the test to verify all combinations as originally suggested.
If you decide not to implement transpose_b=False, that's fine; the test will expect an error so the known issue isn't forgotten.
If you fix the issue for transpose_b=False, then all combinations should pass without expecting an error.

-        compress_weights(
-            model,
-            mode=CompressWeightsMode.INT4_SYM,
-            ratio=1.0,
-            group_size=8,
-            subset_size=2,
-            dataset=dataset,
-            all_layers=True,
-            **kwargs,
-        )
+    compress_weights(
+        model,
+        mode=CompressWeightsMode.INT4_SYM,
+        ratio=1.0,
+        group_size=8,
+        subset_size=2,
+        dataset=dataset,
+        all_layers=True,
+        **kwargs,
+    )


class TestOVTemplateWeightCompression(TemplateWeightCompression):
Expand Down