From ada9c0a002a11dc828053690e87341304682cdcd Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Mon, 17 Feb 2025 20:35:20 +0400 Subject: [PATCH 01/16] Lora correction input transpose support --- .../algorithms/weight_compression/algorithm.py | 7 +++++-- .../algorithms/weight_compression/openvino_backend.py | 5 +---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index fa07273fa69..0cbd1ee113b 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -776,9 +776,12 @@ def get_statistic_points( ) # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden # size dimension. - n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) + output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0] + transpose = output_edge.to_node.layer_attributes.input_attributes['transpose'] + n_dims = len(output_edge.tensor_shape) + reduction_axes = tuple(range(n_dims - 1)) if not transpose else tuple(i for i in range(n_dims) if i != n_dims - 2) stat_collector = self._backend_entity.mean_statistic_collector( - reduction_axes=tuple(range(n_dims - 1)), subset_size=self._subset_size + reduction_axes=reduction_axes, subset_size=self._subset_size ) statistic_container.add_statistic_point( StatisticPoint( diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 81212bb36fa..f1d999ee8e7 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -110,9 +110,6 @@ def mean_statistic_collector( @staticmethod def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: - if node.layer_attributes.input_attributes["transpose"]: - msg = "Transposed input is not supported" - raise nncf.UnsupportedModelError(msg) constant_ports = node.layer_attributes.get_const_port_ids() activation_ports = [ e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports @@ -204,7 +201,7 @@ def insert_adapters( A_W = opset.constant(lora_A.data) B_W = opset.constant(lora_B.data) - A_MM = opset.matmul(input_node, A_W, transpose_a=False, transpose_b=True) + A_MM = opset.matmul(input_node, A_W, transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes['transpose'], transpose_b=True) B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True) node_output_port = mm_node.output(0) From a18ba70c5f98a3a7400afad08eb5d9d31a164323 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Mon, 24 Feb 2025 22:39:44 +0400 Subject: [PATCH 02/16] OV backend act transpose support --- .../algorithms/weight_compression/gptq.py | 13 +++++------ .../weight_compression/scale_estimation.py | 4 ++-- .../quantization/test_weights_compression.py | 23 +++++++++---------- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index ddb5b83b1ae..1222764c614 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -170,19 +170,17 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: if node.metatype in self._backend_entity.convolution_metatypes: msg = "Convolution metatypes are not supported" raise nncf.UnsupportedModelError(msg) - if node.layer_attributes.input_attributes["transpose"]: - msg = "Transposed input is not supported" - raise nncf.UnsupportedModelError(msg) + hidden_dim = -2 if node.layer_attributes.input_attributes['transpose'] else -1 hessian = fns.zeros( - (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32 + (inputs[0].shape[hidden_dim], inputs[0].shape[hidden_dim]), backend=inputs[0].backend, dtype=TensorDataType.float32 ) for inp in inputs: batch_size = 1 if len(inp.shape) == 2 else inp.shape[0] if node.metatype in self._backend_entity.matmul_metatypes: if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.reshape((-1, inp.shape[hidden_dim])) inp = fns.transpose(inp) hessian *= nsamples / (nsamples + batch_size) nsamples += batch_size @@ -267,8 +265,9 @@ def _quantize_weights( scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: - activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] - wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations) + transpose = wc_params.node_with_weight.layer_attributes.input_attributes['transpose'] + activations = [inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs] if transpose else [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] + wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations, transpose) scale, zero_point = ScaleEstimation.calculate_quantization_params( wc_statistics, weight_tensor[:, (i1 + i) : (i1 + i + group_size)], diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 6bfbaebbd83..64eeb978d24 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -365,7 +365,7 @@ def calculate_quantization_params( return result_scale, zp @staticmethod - def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic: + def activations_to_wc_statistics(activations: List[Tensor], transpose: bool) -> WCTensorStatistic: """ Mimic the activation reducing logic from WeightCompression.get_statistic_points. @@ -376,7 +376,7 @@ def activations_to_wc_statistics(activations: List[Tensor]) -> WCTensorStatistic shapes = [] for act in activations: shapes.append(act.shape) - reduction_shape = tuple(range(act.ndim - 1)) + reduction_shape = tuple(i for i in range(act.ndim) if i != act.ndim - 2) if transpose else tuple(range(act.ndim - 1)) mean_values.append(fns.mean(act, axis=reduction_shape)) wc_statistics = WCTensorStatistic(mean_values, shapes) return wc_statistics diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index ffe12ae02dc..aba99c6247d 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -1481,21 +1481,20 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): ) def test_compression_with_transposed_activations(kwargs): dataset_size = 4 - model = LMLinearModel(transpose_a=True, transpose_b=False).ov_model + model = LMLinearModel(transpose_a=True, transpose_b=True).ov_model input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size dataset = Dataset(input_data) - with pytest.raises(nncf.UnsupportedModelError): - compress_weights( - model, - mode=CompressWeightsMode.INT4_SYM, - ratio=1.0, - group_size=8, - subset_size=2, - dataset=dataset, - all_layers=True, - **kwargs, - ) + compress_weights( + model, + mode=CompressWeightsMode.INT4_SYM, + ratio=1.0, + group_size=8, + subset_size=2, + dataset=dataset, + all_layers=True, + **kwargs, + ) class TestOVTemplateWeightCompression(TemplateWeightCompression): From d69ff3fcedce034657a09780c115ceaf7073536f Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Mon, 24 Feb 2025 22:58:41 +0400 Subject: [PATCH 03/16] pre-commit fix --- .../algorithms/weight_compression/algorithm.py | 6 ++++-- .../algorithms/weight_compression/gptq.py | 14 ++++++++++---- .../weight_compression/openvino_backend.py | 7 ++++++- .../weight_compression/scale_estimation.py | 4 +++- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 75c0e103a2f..3ac0bf4860b 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -773,9 +773,11 @@ def get_statistic_points( # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden # size dimension. output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0] - transpose = output_edge.to_node.layer_attributes.input_attributes['transpose'] + transpose = output_edge.to_node.layer_attributes.input_attributes["transpose"] n_dims = len(output_edge.tensor_shape) - reduction_axes = tuple(range(n_dims - 1)) if not transpose else tuple(i for i in range(n_dims) if i != n_dims - 2) + reduction_axes = ( + tuple(range(n_dims - 1)) if not transpose else tuple(i for i in range(n_dims) if i != n_dims - 2) + ) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size ) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 1222764c614..0da0573ad50 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -171,9 +171,11 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: msg = "Convolution metatypes are not supported" raise nncf.UnsupportedModelError(msg) - hidden_dim = -2 if node.layer_attributes.input_attributes['transpose'] else -1 + hidden_dim = -2 if node.layer_attributes.input_attributes["transpose"] else -1 hessian = fns.zeros( - (inputs[0].shape[hidden_dim], inputs[0].shape[hidden_dim]), backend=inputs[0].backend, dtype=TensorDataType.float32 + (inputs[0].shape[hidden_dim], inputs[0].shape[hidden_dim]), + backend=inputs[0].backend, + dtype=TensorDataType.float32, ) for inp in inputs: @@ -265,8 +267,12 @@ def _quantize_weights( scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: - transpose = wc_params.node_with_weight.layer_attributes.input_attributes['transpose'] - activations = [inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs] if transpose else [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] + transpose = wc_params.node_with_weight.layer_attributes.input_attributes["transpose"] + activations = ( + [inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs] + if transpose + else [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] + ) wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations, transpose) scale, zero_point = ScaleEstimation.calculate_quantization_params( wc_statistics, diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index f1d999ee8e7..20b21a8f906 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -201,7 +201,12 @@ def insert_adapters( A_W = opset.constant(lora_A.data) B_W = opset.constant(lora_B.data) - A_MM = opset.matmul(input_node, A_W, transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes['transpose'], transpose_b=True) + A_MM = opset.matmul( + input_node, + A_W, + transpose_a=wc_params.node_with_weight.layer_attributes.input_attributes["transpose"], + transpose_b=True, + ) B_MM = opset.matmul(A_MM, B_W, transpose_a=False, transpose_b=True) node_output_port = mm_node.output(0) diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 64eeb978d24..af89c6bf3c7 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -376,7 +376,9 @@ def activations_to_wc_statistics(activations: List[Tensor], transpose: bool) -> shapes = [] for act in activations: shapes.append(act.shape) - reduction_shape = tuple(i for i in range(act.ndim) if i != act.ndim - 2) if transpose else tuple(range(act.ndim - 1)) + reduction_shape = ( + tuple(i for i in range(act.ndim) if i != act.ndim - 2) if transpose else tuple(range(act.ndim - 1)) + ) mean_values.append(fns.mean(act, axis=reduction_shape)) wc_statistics = WCTensorStatistic(mean_values, shapes) return wc_statistics From d493314ae8201f88814e831b6792fac97c7ba684 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:06:55 +0400 Subject: [PATCH 04/16] Brute force solution --- nncf/quantization/algorithms/weight_compression/algorithm.py | 5 +---- nncf/quantization/algorithms/weight_compression/backend.py | 4 ++++ nncf/quantization/algorithms/weight_compression/gptq.py | 4 ++-- .../algorithms/weight_compression/openvino_backend.py | 4 ++++ 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 3ac0bf4860b..57b29437114 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -773,11 +773,8 @@ def get_statistic_points( # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden # size dimension. output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0] - transpose = output_edge.to_node.layer_attributes.input_attributes["transpose"] n_dims = len(output_edge.tensor_shape) - reduction_axes = ( - tuple(range(n_dims - 1)) if not transpose else tuple(i for i in range(n_dims) if i != n_dims - 2) - ) + reduction_axes = tuple(i for i in range(n_dims) if i != n_dims + self._backend_entity.get_input_hidden_dim(output_edge.to_node)) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size ) diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py index c8ea964a288..184bca2734d 100644 --- a/nncf/quantization/algorithms/weight_compression/backend.py +++ b/nncf/quantization/algorithms/weight_compression/backend.py @@ -246,6 +246,10 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> :return: Backend-specific callable to filter statistic containers according to its statistic point. """ + @staticmethod + def get_input_hidden_dim(input_node: NNCFNode) -> int: + return -1 + class AWQAlgoBackend(WeightCompressionAlgoBackend): @staticmethod diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 0da0573ad50..0b2b0afaa06 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -171,7 +171,7 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: msg = "Convolution metatypes are not supported" raise nncf.UnsupportedModelError(msg) - hidden_dim = -2 if node.layer_attributes.input_attributes["transpose"] else -1 + hidden_dim = self._backend_entity.get_input_hidden_dim(node) hessian = fns.zeros( (inputs[0].shape[hidden_dim], inputs[0].shape[hidden_dim]), backend=inputs[0].backend, @@ -267,7 +267,7 @@ def _quantize_weights( scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: - transpose = wc_params.node_with_weight.layer_attributes.input_attributes["transpose"] + transpose = True if self._backend_entity.get_input_hidden_dim(wc_params.node_with_weight) == -2 else False activations = ( [inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs] if transpose diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 20b21a8f906..60ffcc919e3 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -365,6 +365,10 @@ def filter_func(point: StatisticPoint) -> bool: ) return filter_func + + @staticmethod + def get_input_hidden_dim(node: NNCFNode) -> int: + return -2 if node.layer_attributes.input_attributes['transpose'] else -1 class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend): From 202815b0e641e49e6fc0854692154b3bc15663c0 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:09:39 +0400 Subject: [PATCH 05/16] Minor fix --- nncf/quantization/algorithms/weight_compression/algorithm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 57b29437114..caa1fe08f06 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -772,8 +772,7 @@ def get_statistic_points( ) # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden # size dimension. - output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0] - n_dims = len(output_edge.tensor_shape) + n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) reduction_axes = tuple(i for i in range(n_dims) if i != n_dims + self._backend_entity.get_input_hidden_dim(output_edge.to_node)) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size From 1dc14ba72fdc28f943b0cd7cdec63332b9a3b9fa Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:13:30 +0400 Subject: [PATCH 06/16] pre-commit fix --- .../algorithms/weight_compression/algorithm.py | 9 +++++++-- .../algorithms/weight_compression/backend.py | 2 +- nncf/quantization/algorithms/weight_compression/gptq.py | 2 +- .../algorithms/weight_compression/openvino_backend.py | 4 ++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index caa1fe08f06..91e932df636 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -772,8 +772,13 @@ def get_statistic_points( ) # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden # size dimension. - n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) - reduction_axes = tuple(i for i in range(n_dims) if i != n_dims + self._backend_entity.get_input_hidden_dim(output_edge.to_node)) + output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0] + n_dims = len(output_edge.tensor_shape) + reduction_axes = tuple( + i + for i in range(n_dims) + if i != n_dims + self._backend_entity.get_input_hidden_dim(output_edge.to_node) + ) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size ) diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py index 184bca2734d..3c1ab739e6d 100644 --- a/nncf/quantization/algorithms/weight_compression/backend.py +++ b/nncf/quantization/algorithms/weight_compression/backend.py @@ -249,7 +249,7 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> @staticmethod def get_input_hidden_dim(input_node: NNCFNode) -> int: return -1 - + class AWQAlgoBackend(WeightCompressionAlgoBackend): @staticmethod diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 0b2b0afaa06..7d425a3618c 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -267,7 +267,7 @@ def _quantize_weights( scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: - transpose = True if self._backend_entity.get_input_hidden_dim(wc_params.node_with_weight) == -2 else False + transpose = self._backend_entity.get_input_hidden_dim(wc_params.node_with_weight) == -2 activations = ( [inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs] if transpose diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 60ffcc919e3..8540b1a34c5 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -365,10 +365,10 @@ def filter_func(point: StatisticPoint) -> bool: ) return filter_func - + @staticmethod def get_input_hidden_dim(node: NNCFNode) -> int: - return -2 if node.layer_attributes.input_attributes['transpose'] else -1 + return -2 if node.layer_attributes.input_attributes["transpose"] else -1 class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend): From 50eddb42a5b216bc971335008883cb895a3f4a0d Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Thu, 27 Feb 2025 23:10:45 +0400 Subject: [PATCH 07/16] attempt fix --- .../algorithms/weight_compression/openvino_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 8540b1a34c5..52197894107 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -368,7 +368,9 @@ def filter_func(point: StatisticPoint) -> bool: @staticmethod def get_input_hidden_dim(node: NNCFNode) -> int: - return -2 if node.layer_attributes.input_attributes["transpose"] else -1 + if (node is not None) and (node.layer_attributes is not None): + return -2 if node.layer_attributes.input_attributes["transpose"] else -1 + return -1 class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend): From 9dccf216453eb268de0d43891f4a801ef7fcfc31 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Sun, 2 Mar 2025 17:14:37 +0400 Subject: [PATCH 08/16] Add doc string --- nncf/quantization/algorithms/weight_compression/backend.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py index 3c1ab739e6d..efb0ed7d004 100644 --- a/nncf/quantization/algorithms/weight_compression/backend.py +++ b/nncf/quantization/algorithms/weight_compression/backend.py @@ -248,6 +248,12 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> @staticmethod def get_input_hidden_dim(input_node: NNCFNode) -> int: + """ + Returns the index of the hidden dimension in the shape of the input node. + + :param input_node: The input node. + :return: The index of the hidden dimension in the shape of the input node. + """ return -1 From aa062dc27d07af18abfa7319b78d2a924738057e Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Wed, 5 Mar 2025 05:43:04 +0400 Subject: [PATCH 09/16] Implement get_activation_channel_axis --- .../weight_compression/algorithm.py | 14 +++++----- .../algorithms/weight_compression/backend.py | 12 +++++---- .../algorithms/weight_compression/gptq.py | 26 ++++++++++++------- .../weight_compression/openvino_backend.py | 7 +++-- .../weight_compression/scale_estimation.py | 6 ++--- .../weight_compression/torch_backend.py | 4 +++ .../weight_compression/torch_fx_backend.py | 4 +++ .../quantization/test_weights_compression.py | 1 + 8 files changed, 44 insertions(+), 30 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index d9babef85b4..fbbcfe7292a 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -772,15 +772,15 @@ def get_statistic_points( statistic_point = self._backend_entity.target_point( TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id ) - # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden - # size dimension. + # Reduce activations across all but the hidden dimension. output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0] - n_dims = len(output_edge.tensor_shape) - reduction_axes = tuple( - i - for i in range(n_dims) - if i != n_dims + self._backend_entity.get_input_hidden_dim(output_edge.to_node) + node = output_edge.to_node + input_shape = output_edge.tensor_shape + input_channel_axis = self._backend_entity.get_activation_channel_axis( + node, self._backend_entity.get_activation_port_id(node, graph), input_shape ) + n_dims = len(input_shape) + reduction_axes = tuple(set(range(n_dims)) - {input_channel_axis}) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size ) diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py index efb0ed7d004..1e5bd9aed62 100644 --- a/nncf/quantization/algorithms/weight_compression/backend.py +++ b/nncf/quantization/algorithms/weight_compression/backend.py @@ -247,14 +247,16 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> """ @staticmethod - def get_input_hidden_dim(input_node: NNCFNode) -> int: + @abstractmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: """ - Returns the index of the hidden dimension in the shape of the input node. + Returns axis number of the activation tensor which correspond to it channel. - :param input_node: The input node. - :return: The index of the hidden dimension in the shape of the input node. + :param node: NNCFNode instance. + :param port_id: Port ID for input. + :param input_shape: Shape of the input. + :return: Channel axis number. """ - return -1 class AWQAlgoBackend(WeightCompressionAlgoBackend): diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 7d425a3618c..ca0e2c730c4 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -123,8 +123,13 @@ def apply( ]: continue _, input_tensors = next(iter(inputs.items())) - hessian = self._calculate_hessian(node, input_tensors) - scale, zero_point = self._quantize_weights(model, graph, wc_params, hessian, input_tensors) + input_channel_axis = self._backend_entity.get_activation_channel_axis( + node, self._backend_entity.get_activation_port_id(node, graph), input_tensors[0].shape + ) + hessian = self._calculate_hessian(node, input_tensors, input_channel_axis) + scale, zero_point = self._quantize_weights( + model, graph, wc_params, hessian, input_tensors, input_channel_axis + ) scales[wc_params.weight_name] = scale zero_points[wc_params.weight_name] = zero_point @@ -157,7 +162,7 @@ def get_statistic_points( return self._layerwise_engine.get_statistic_points(model, graph, filtered_nodes) - def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: + def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor], input_channel_axis: int) -> Tensor: """ Calculates the Hessian matrix for the given node and inputs. @@ -171,9 +176,8 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: msg = "Convolution metatypes are not supported" raise nncf.UnsupportedModelError(msg) - hidden_dim = self._backend_entity.get_input_hidden_dim(node) hessian = fns.zeros( - (inputs[0].shape[hidden_dim], inputs[0].shape[hidden_dim]), + (inputs[0].shape[input_channel_axis], inputs[0].shape[input_channel_axis]), backend=inputs[0].backend, dtype=TensorDataType.float32, ) @@ -182,7 +186,7 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: batch_size = 1 if len(inp.shape) == 2 else inp.shape[0] if node.metatype in self._backend_entity.matmul_metatypes: if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[hidden_dim])) + inp = inp.reshape((-1, inp.shape[input_channel_axis])) inp = fns.transpose(inp) hessian *= nsamples / (nsamples + batch_size) nsamples += batch_size @@ -198,6 +202,7 @@ def _quantize_weights( wc_params: WeightCompressionParameters, hessian: Tensor, inputs: List[Tensor], + input_channel_axis: int, ): """ Quantizes the weights of the model based on the calculated Hessian matrix. @@ -267,13 +272,14 @@ def _quantize_weights( scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: - transpose = self._backend_entity.get_input_hidden_dim(wc_params.node_with_weight) == -2 activations = ( - [inp[:, (i1 + i) : (i1 + i + group_size), ...] for inp in inputs] - if transpose + [inp[..., (i1 + i) : (i1 + i + group_size), :] for inp in inputs] + if input_channel_axis != (len(inputs[0].shape) - 1) else [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] ) - wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations, transpose) + wc_statistics = ScaleEstimation.activations_to_wc_statistics( + activations, input_channel_axis + ) scale, zero_point = ScaleEstimation.calculate_quantization_params( wc_statistics, weight_tensor[:, (i1 + i) : (i1 + i + group_size)], diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index c2a4e9663c5..71c844be530 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -33,6 +33,7 @@ from nncf.openvino.graph.model_transformer import OVModelTransformer from nncf.openvino.graph.node_utils import convert_op from nncf.openvino.graph.node_utils import create_ov_const_from_tensor +from nncf.openvino.graph.node_utils import get_activation_channel_axis from nncf.openvino.graph.node_utils import get_const_value_as_numpy_tensor from nncf.openvino.graph.node_utils import get_const_value_as_ov_tensor from nncf.openvino.graph.node_utils import get_weight_channel_axes @@ -366,10 +367,8 @@ def filter_func(point: StatisticPoint) -> bool: return filter_func @staticmethod - def get_input_hidden_dim(node: NNCFNode) -> int: - if (node is not None) and (node.layer_attributes is not None): - return -2 if node.layer_attributes.input_attributes["transpose"] else -1 - return -1 + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: + return get_activation_channel_axis(node, port_id, input_shape) class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend): diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index af89c6bf3c7..b41d3e91306 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -365,7 +365,7 @@ def calculate_quantization_params( return result_scale, zp @staticmethod - def activations_to_wc_statistics(activations: List[Tensor], transpose: bool) -> WCTensorStatistic: + def activations_to_wc_statistics(activations: List[Tensor], input_channel_axis: int) -> WCTensorStatistic: """ Mimic the activation reducing logic from WeightCompression.get_statistic_points. @@ -376,9 +376,7 @@ def activations_to_wc_statistics(activations: List[Tensor], transpose: bool) -> shapes = [] for act in activations: shapes.append(act.shape) - reduction_shape = ( - tuple(i for i in range(act.ndim) if i != act.ndim - 2) if transpose else tuple(range(act.ndim - 1)) - ) + reduction_shape = tuple(set(range(len(act.shape))) - {input_channel_axis}) mean_values.append(fns.mean(act, axis=reduction_shape)) wc_statistics = WCTensorStatistic(mean_values, shapes) return wc_statistics diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index 6001a4c9025..e4d66ce24eb 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -326,6 +326,10 @@ def transform_model( return transformed_model + @staticmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: + return node.metatype.output_channel_axis + class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend): @staticmethod diff --git a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py index 06cd8b845e8..b7e45366eeb 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py @@ -266,3 +266,7 @@ def transform_model( transformed_model = FXModelTransformer(model).transform(transformation_layout) return transformed_model + + @staticmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: + return node.metatype.output_channel_axis diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index c002df37932..81fdaebf08a 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -1437,6 +1437,7 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)), ), ], + ids=["se", "lora", "gptq_se_awq"], ) def test_compression_with_transposed_activations(kwargs): dataset_size = 4 From 051a18356da213c72855f7d29a956edccebd345e Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Wed, 5 Mar 2025 11:55:48 +0400 Subject: [PATCH 10/16] fix test --- tests/openvino/native/quantization/test_gptq.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/openvino/native/quantization/test_gptq.py b/tests/openvino/native/quantization/test_gptq.py index a141d7c99fe..8d12703a720 100644 --- a/tests/openvino/native/quantization/test_gptq.py +++ b/tests/openvino/native/quantization/test_gptq.py @@ -346,7 +346,10 @@ def test_calculate_scale_linear(): nodes = graph.get_all_nodes() wrapped_inputs = [Tensor(inp) for inp in inputs] - H = gptq._calculate_hessian(nodes[1], wrapped_inputs) + input_channel_axis = gptq._backend_entity.get_activation_channel_axis( + nodes[1], gptq._backend_entity.get_activation_port_id(nodes[1], graph), wrapped_inputs[0].shape + ) + H = gptq._calculate_hessian(nodes[1], wrapped_inputs, input_channel_axis) ref_H = ref_gptq.H.numpy() assert np.all(np.isclose(ref_H, H.data)) @@ -356,7 +359,7 @@ def test_calculate_scale_linear(): ) wc_params.compression_config = WeightCompressionConfig(mode=CompressWeightsMode.INT4_SYM, group_size=16) - scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs) + scale, _ = gptq._quantize_weights(ov_model, graph, wc_params, H, wrapped_inputs, input_channel_axis) ref_scale = ref_scale.numpy() scale = scale.reshape(ref_scale.shape) assert np.all(np.isclose(ref_scale, scale.data)) From 7f7f4685e272f14ac2c714a026b80cf41da1ba36 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Wed, 5 Mar 2025 12:06:55 +0400 Subject: [PATCH 11/16] Fix error --- nncf/quantization/algorithms/weight_compression/algorithm.py | 2 +- .../algorithms/weight_compression/scale_estimation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index fbbcfe7292a..170d581d949 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -780,7 +780,7 @@ def get_statistic_points( node, self._backend_entity.get_activation_port_id(node, graph), input_shape ) n_dims = len(input_shape) - reduction_axes = tuple(set(range(n_dims)) - {input_channel_axis}) + reduction_axes = tuple(set(range(n_dims)) - {input_channel_axis % n_dims}) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size ) diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index b41d3e91306..f71ea1df43a 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -376,7 +376,7 @@ def activations_to_wc_statistics(activations: List[Tensor], input_channel_axis: shapes = [] for act in activations: shapes.append(act.shape) - reduction_shape = tuple(set(range(len(act.shape))) - {input_channel_axis}) + reduction_shape = tuple(set(range(len(act.shape))) - {input_channel_axis % len(act.shape)}) mean_values.append(fns.mean(act, axis=reduction_shape)) wc_statistics = WCTensorStatistic(mean_values, shapes) return wc_statistics From e19804c6a2736ca7f7210dcd76bce9a66d0e8a14 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Fri, 14 Mar 2025 01:20:58 +0400 Subject: [PATCH 12/16] Fix OV NNCF Graph Builder add edges --- nncf/openvino/graph/nncf_graph_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/openvino/graph/nncf_graph_builder.py b/nncf/openvino/graph/nncf_graph_builder.py index a4df03ccd7e..a50f1570dde 100644 --- a/nncf/openvino/graph/nncf_graph_builder.py +++ b/nncf/openvino/graph/nncf_graph_builder.py @@ -97,7 +97,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None: in_node_id = graph.get_node_by_name(op.get_friendly_name()).node_id for output_port_id, out in enumerate(op.outputs()): node_vs_target_inputs = defaultdict(list) - for inp in out.get_target_inputs(): + for inp in sorted(out.get_target_inputs(), key=lambda inp: inp.get_node().get_type_name()): node_vs_target_inputs[inp.get_node()].append(inp) for out_node, inputs in node_vs_target_inputs.items(): From 80163998a67bc1856b931181b709067372a20a30 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Tue, 18 Mar 2025 19:13:55 +0400 Subject: [PATCH 13/16] Update --- nncf/openvino/graph/nncf_graph_builder.py | 2 +- .../weight_compression/algorithm.py | 29 +++++++++---------- .../weight_compression/mixed_precision.py | 2 +- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/nncf/openvino/graph/nncf_graph_builder.py b/nncf/openvino/graph/nncf_graph_builder.py index a50f1570dde..14f9e19b3cc 100644 --- a/nncf/openvino/graph/nncf_graph_builder.py +++ b/nncf/openvino/graph/nncf_graph_builder.py @@ -97,7 +97,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None: in_node_id = graph.get_node_by_name(op.get_friendly_name()).node_id for output_port_id, out in enumerate(op.outputs()): node_vs_target_inputs = defaultdict(list) - for inp in sorted(out.get_target_inputs(), key=lambda inp: inp.get_node().get_type_name()): + for inp in sorted(out.get_target_inputs(), key=lambda inp: inp.get_node().get_friendly_name()): node_vs_target_inputs[inp.get_node()].append(inp) for out_node, inputs in node_vs_target_inputs.items(): diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 170d581d949..f14c03e1bc9 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -668,19 +668,22 @@ def apply( ) return transformed_model - def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]: + def _get_activation_node_port_and_channel(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]: """ - This method returns the activation layer and corresponding port id for the node. + This method returns the activation layer, corresponding port id and channel axis for the given node. :param node: NNCFGraph node for which the activation is sought. :param nncf_graph: NNCFGraph instance with the node. - :return: Tuple with the activation node and port id. + :return: Tuple with the activation node, port id and channel axis. """ activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph) activation_edge = nncf_graph.get_input_edge_by_port_id(node, activation_port) activation_node = activation_edge.from_node port_id = activation_edge.output_port_id - return activation_node, port_id + activation_channel_axis = self._backend_entity.get_activation_channel_axis( + node, port_id, activation_edge.tensor_shape + ) + return activation_node, port_id, activation_channel_axis def get_matmul_input_to_output_nodes_map( self, matmul_nodes: List[NNCFNode], graph: NNCFGraph @@ -701,8 +704,8 @@ def get_matmul_input_to_output_nodes_map( """ matmul_input_to_output_nodes_map = defaultdict(list) for node in matmul_nodes: - act_node, output_port_id = self._get_activation_node_and_port(node, graph) - matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node) + act_node, output_port_id, act_channel_axis = self._get_activation_node_port_and_channel(node, graph) + matmul_input_to_output_nodes_map[(act_node, output_port_id, act_channel_axis)].append(node) return matmul_input_to_output_nodes_map def get_compression_nodes_info( @@ -768,19 +771,13 @@ def get_statistic_points( statistic_container = StatisticPointsContainer() # Statistics for data aware algorithms if self._data_aware_compression: - for node, output_port_id in nodes_and_port_ids: + for node, output_port_id, channel_axis in nodes_and_port_ids: statistic_point = self._backend_entity.target_point( TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id ) # Reduce activations across all but the hidden dimension. - output_edge = graph.get_output_edges_by_port_id(node, output_port_id)[0] - node = output_edge.to_node - input_shape = output_edge.tensor_shape - input_channel_axis = self._backend_entity.get_activation_channel_axis( - node, self._backend_entity.get_activation_port_id(node, graph), input_shape - ) - n_dims = len(input_shape) - reduction_axes = tuple(set(range(n_dims)) - {input_channel_axis % n_dims}) + n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) + reduction_axes = tuple(set(range(n_dims)) - {channel_axis % n_dims}) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size ) @@ -819,7 +816,7 @@ def _get_statistics_for_weights_compression( # Where mean_value is a 1D tensor representing an activation reduced over batch and sequence length dimensions, # shape is an original shape of an activation before reduction, n is the size of the dataset (or subset_size). statistics = {} - for (act_node, output_port_id), matmul_nodes in matmul_input_to_output_nodes_map.items(): + for (act_node, output_port_id, _), matmul_nodes in matmul_input_to_output_nodes_map.items(): tensor_collectors = list( statistic_points.get_algo_statistics_for_node( act_node.node_name, diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py index c8f5f175d6f..1840d1fb681 100644 --- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -269,7 +269,7 @@ def get_statistic_points( self._set_backend_entity(model) statistic_container = StatisticPointsContainer() - for act_node, output_port_id in nodes_and_port_ids: + for act_node, output_port_id, _ in nodes_and_port_ids: n_dims = len(graph.get_output_edges_by_port_id(act_node, output_port_id)[0].tensor_shape) if n_dims < 2: msg = ( From 298b587ebec1c938eade284ba32a6205ed29114d Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Tue, 18 Mar 2025 19:23:08 +0400 Subject: [PATCH 14/16] Update --- nncf/quantization/algorithms/weight_compression/algorithm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 032015e102d..e73ae7c3c9d 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -789,13 +789,13 @@ def get_statistic_points( statistic_container = StatisticPointsContainer() # Statistics for data aware algorithms if self._data_aware_compression: - for node, output_port_id, channel_axis in nodes_and_port_ids: + for node, output_port_id, input_channel_axis in nodes_and_port_ids: statistic_point = self._backend_entity.target_point( TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id ) # Reduce activations across all but the hidden dimension. n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) - reduction_axes = tuple(set(range(n_dims)) - {channel_axis % n_dims}) + reduction_axes = tuple(set(range(n_dims)) - {input_channel_axis % n_dims}) stat_collector = self._backend_entity.mean_statistic_collector( reduction_axes=reduction_axes, subset_size=self._subset_size ) From 9a7cb0259a2c659905910bc2982847d1b35d8ccd Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Thu, 20 Mar 2025 00:31:36 +0400 Subject: [PATCH 15/16] Update --- .../algorithms/weight_compression/gptq.py | 4 +- .../weight_compression/scale_estimation.py | 7 +++ .../quantization/test_weights_compression.py | 50 +++++++++++++------ 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index ca0e2c730c4..927e6c40571 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -215,10 +215,10 @@ def _quantize_weights( """ if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes: msg = "Convolution metatypes are not supported" - raise RuntimeError(msg) + raise nncf.UnsupportedModelError(msg) if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]: msg = "Transpose is not supported" - raise RuntimeError(msg) + raise nncf.UnsupportedModelError(msg) weight_tensor = self._backend_entity.get_weight( wc_params.node_with_weight, wc_params.weight_port_id, model, graph diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 6add4fdc018..eb3d0a77582 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -117,6 +117,13 @@ def apply( scales, zero_points = dict(), dict() for wp in track(all_weight_params, description="Applying Scale Estimation"): + if ( + wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes + and not wp.node_with_weight.layer_attributes.constant_attributes[wp.weight_port_id]["transpose"] + ): + msg = "Transpose is not supported" + raise nncf.UnsupportedModelError(msg) + weight_name = wp.weight_name node_name = wp.node_with_weight.node_name config = wp.compression_config diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 050130132cf..1def6570173 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -11,7 +11,8 @@ import inspect import os -from typing import Callable, Dict, List +from contextlib import nullcontext +from typing import Callable, Dict, List, Optional import numpy as np import openvino.runtime as ov @@ -89,7 +90,9 @@ class LMLinearModel(OVReferenceModel): HIDDEN_DIM = 16 INPUT_SHAPE = [1, 24, HIDDEN_DIM] # [B, SeqLen, HiddenDim] - def _create_ov_model(self, transpose_b: bool = True, transpose_a=False, input_shape=None): + def _create_ov_model( + self, transpose_b: bool = True, transpose_a: bool = False, input_shape: Optional[List[int]] = None + ): self._input_shape = self.INPUT_SHAPE if input_shape is None else input_shape hdim_axis = -2 if transpose_a else -1 self._hidden_dim = self._input_shape[hdim_axis] @@ -1455,9 +1458,19 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): ) +@pytest.mark.parametrize( + ("transpose_a", "transpose_b", "raises_error"), + ( + (False, True, False), + (True, True, False), + (False, False, True), + (True, False, True), + ), + ids=["tb_nota", "ta_tb", "nota_notb", "ta_notb"], +) @pytest.mark.parametrize( "kwargs", - [ + ( dict(scale_estimation=True), dict(lora_correction=True), dict( @@ -1466,25 +1479,30 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): scale_estimation=True, advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)), ), - ], + ), ids=["se", "lora", "gptq_se_awq"], ) -def test_compression_with_transposed_activations(kwargs): +def test_compression_with_transpose(transpose_a, transpose_b, raises_error, kwargs): dataset_size = 4 - model = LMLinearModel(transpose_a=True, transpose_b=True).ov_model + model = LMLinearModel(transpose_a=transpose_a, transpose_b=transpose_b).ov_model input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size dataset = Dataset(input_data) - compress_weights( - model, - mode=CompressWeightsMode.INT4_SYM, - ratio=1.0, - group_size=8, - subset_size=2, - dataset=dataset, - all_layers=True, - **kwargs, - ) + with ( + pytest.raises(nncf.UnsupportedModelError) + if raises_error and not kwargs.get("lora_correction", False) + else nullcontext() + ): + compress_weights( + model, + mode=CompressWeightsMode.INT4_SYM, + ratio=1.0, + group_size=8, + subset_size=2, + dataset=dataset, + all_layers=True, + **kwargs, + ) class TestOVTemplateWeightCompression(TemplateWeightCompression): From 341c4a849b3fee9c36129a67790ed79c68b1a3e0 Mon Sep 17 00:00:00 2001 From: Riffat Khan <83776178+rk119@users.noreply.github.com> Date: Thu, 20 Mar 2025 01:16:17 +0400 Subject: [PATCH 16/16] Update --- .../algorithms/weight_compression/gptq.py | 4 +- .../weight_compression/scale_estimation.py | 7 ---- .../quantization/test_weights_compression.py | 40 ++++++------------- 3 files changed, 14 insertions(+), 37 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 927e6c40571..ca0e2c730c4 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -215,10 +215,10 @@ def _quantize_weights( """ if wc_params.node_with_weight.metatype in self._backend_entity.convolution_metatypes: msg = "Convolution metatypes are not supported" - raise nncf.UnsupportedModelError(msg) + raise RuntimeError(msg) if not wc_params.node_with_weight.layer_attributes.constant_attributes[wc_params.weight_port_id]["transpose"]: msg = "Transpose is not supported" - raise nncf.UnsupportedModelError(msg) + raise RuntimeError(msg) weight_tensor = self._backend_entity.get_weight( wc_params.node_with_weight, wc_params.weight_port_id, model, graph diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index eb3d0a77582..6add4fdc018 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -117,13 +117,6 @@ def apply( scales, zero_points = dict(), dict() for wp in track(all_weight_params, description="Applying Scale Estimation"): - if ( - wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes - and not wp.node_with_weight.layer_attributes.constant_attributes[wp.weight_port_id]["transpose"] - ): - msg = "Transpose is not supported" - raise nncf.UnsupportedModelError(msg) - weight_name = wp.weight_name node_name = wp.node_with_weight.node_name config = wp.compression_config diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 1def6570173..58bbc4a99ec 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -11,7 +11,6 @@ import inspect import os -from contextlib import nullcontext from typing import Callable, Dict, List, Optional import numpy as np @@ -1458,16 +1457,6 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): ) -@pytest.mark.parametrize( - ("transpose_a", "transpose_b", "raises_error"), - ( - (False, True, False), - (True, True, False), - (False, False, True), - (True, False, True), - ), - ids=["tb_nota", "ta_tb", "nota_notb", "ta_notb"], -) @pytest.mark.parametrize( "kwargs", ( @@ -1482,27 +1471,22 @@ def test_compression_with_different_algo_combinations(input_shape, kwargs): ), ids=["se", "lora", "gptq_se_awq"], ) -def test_compression_with_transpose(transpose_a, transpose_b, raises_error, kwargs): +def test_compression_with_transpose(kwargs): dataset_size = 4 - model = LMLinearModel(transpose_a=transpose_a, transpose_b=transpose_b).ov_model + model = LMLinearModel(transpose_a=True, transpose_b=True).ov_model input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size dataset = Dataset(input_data) - with ( - pytest.raises(nncf.UnsupportedModelError) - if raises_error and not kwargs.get("lora_correction", False) - else nullcontext() - ): - compress_weights( - model, - mode=CompressWeightsMode.INT4_SYM, - ratio=1.0, - group_size=8, - subset_size=2, - dataset=dataset, - all_layers=True, - **kwargs, - ) + compress_weights( + model, + mode=CompressWeightsMode.INT4_SYM, + ratio=1.0, + group_size=8, + subset_size=2, + dataset=dataset, + all_layers=True, + **kwargs, + ) class TestOVTemplateWeightCompression(TemplateWeightCompression):