From b8203a539e476832d94b11ae6dff8d08f768616d Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Fri, 7 Mar 2025 20:59:59 -0500 Subject: [PATCH 01/11] Support weight channel axes --- .../algorithms/min_max/torch_fx_backend.py | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 564d6b12a1d..84cc9a25c49 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -40,6 +40,7 @@ from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.hardware.config import PTHWConfig from nncf.torch.model_graph_manager import get_weight_tensor_port_ids +from nncf.torch.model_graph_manager import get_weight_channel_axes from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT from nncf.torch.quantization.layers import QUANTIZATION_MODULES @@ -177,16 +178,35 @@ def _get_input_scale_shape( nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: is_weights = target_point.is_weight_target_point() + input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) + if is_weights: - # TODO(dlyakhov): support transpose conv/ make channel_idx common - channel_idx = 0 + node = nncf_graph.get_node_by_name(target_point.target_node_name) + # Get channel axes considering weight port ID + channel_axes = get_weight_channel_axes( + node.metatype, + len(input_shape), + target_point.input_port_id + ) + if channel_axes: + channel_idx = channel_axes[0] + scale_shape = tuple(get_scale_shape( + input_shape, + is_weights=True, + per_channel=per_channel, + channel_idx=channel_idx + )) + else: + scale_shape = (1,) + channel_idx = 0 else: - channel_idx = 1 # channel dim for activations - - input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) - scale_shape = tuple( - get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) - ) + channel_idx = 1 + scale_shape = tuple(get_scale_shape( + input_shape, + is_weights=False, + per_channel=per_channel, + channel_idx=channel_idx + )) return input_shape, scale_shape, channel_idx From 474e6b77859a6e87ae3c81469c1eeff4fc3aeec9 Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Fri, 7 Mar 2025 21:01:24 -0500 Subject: [PATCH 02/11] Change minmax algo to support channel axes for ConvTranspose --- .../algorithms/min_max/algorithm.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 220422dad92..7248b856e43 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -444,29 +444,31 @@ def _get_stat_collector( qconfig: QuantizerConfig, batchwise_statistics: bool, ) -> TensorCollector: - """ - Creates and returns a statistic collector based on the quantizer's configuration. - :param graph: NNCFGraph instance. - :param target_point: Target point indicates where statistics should be collected. - :param qconfig: Configuration of a quantizer layer, - defining the configuration of created statistic collector. - :param batchwise_statistics: Determines whether quantizer statistics should be calculated - for each item of the batch or for the entire batch. - :return: Statistic Collector. - """ is_weight = target_point.is_weight_target_point() node = graph.get_node_by_name(target_point.target_node_name) shape = self._backend_entity.get_target_point_shape(graph, node, target_point) - range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) - + + # Get channel axes considering ConvTranspose layers channel_axes = () if qconfig.per_channel: - channel_axes = ( - self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) if is_weight else (1,) - ) + if is_weight: + if node.metatype.__name__.startswith("PTConvTranspose"): + channel_axes = (1,) # Output channels for transpose conv + else: + channel_axes = self._backend_entity.get_weight_quantization_axes( + node, target_point, len(shape) + ) + else: + channel_axes = (1,) - # Weight statistics is constant, so only one collection is enough. + # Align statistics collection with scale shape + reduction_axes, aggregation_axes = None, None + if shape is not None and channel_axes: + all_axes = set(range(len(shape))) + reduction_axes = tuple(all_axes - set(channel_axes)) + + range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) num_samples = self._subset_size if not is_weight else 1 batchwise_statistics = batchwise_statistics and not is_weight From 7319f9fd632cb90805a95ebf50fb636183caeb16 Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Fri, 7 Mar 2025 21:36:43 -0500 Subject: [PATCH 03/11] add comment back algorithm.py --- nncf/quantization/algorithms/min_max/algorithm.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 7248b856e43..7a9d573d675 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -444,7 +444,16 @@ def _get_stat_collector( qconfig: QuantizerConfig, batchwise_statistics: bool, ) -> TensorCollector: + """ + Creates and returns a statistic collector based on the quantizer's configuration. + :param nncf_graph: NNCFGraph instance. + :param target_point: Target point indicates where statistics should be collected. + :param quantizer_config: Configuration of a quantizer layer, + defining the configuration of created statistic collector. + :param num_samples: Number of samples to collect from the 'target_point'. + :return: Statistic Collector. + """ is_weight = target_point.is_weight_target_point() node = graph.get_node_by_name(target_point.target_node_name) shape = self._backend_entity.get_target_point_shape(graph, node, target_point) From 3308b40ec9a35d00ad781b3c57d790b7d8f5950c Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Fri, 7 Mar 2025 21:45:08 -0500 Subject: [PATCH 04/11] add comment algorithm.py --- .../algorithms/min_max/algorithm.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 7a9d573d675..128bd6c00b5 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -439,9 +439,9 @@ def _get_range_estimator_parameters( def _get_stat_collector( self, - graph: NNCFGraph, + nncf_graph: NNCFGraph, target_point: TargetPoint, - qconfig: QuantizerConfig, + quantizer_config: QuantizerConfig, batchwise_statistics: bool, ) -> TensorCollector: """ @@ -451,16 +451,17 @@ def _get_stat_collector( :param target_point: Target point indicates where statistics should be collected. :param quantizer_config: Configuration of a quantizer layer, defining the configuration of created statistic collector. - :param num_samples: Number of samples to collect from the 'target_point'. - :return: Statistic Collector. + :param batchwise_statistics: Determines whether quantizer statistics should be calculated + for each item of the batch or for the entire batch. + :return: TensorCollector for the statistics calculation. """ is_weight = target_point.is_weight_target_point() - node = graph.get_node_by_name(target_point.target_node_name) - shape = self._backend_entity.get_target_point_shape(graph, node, target_point) + node = nncf_graph.get_node_by_name(target_point.target_node_name) + shape = self._backend_entity.get_target_point_shape(nncf_graph, node, target_point) # Get channel axes considering ConvTranspose layers channel_axes = () - if qconfig.per_channel: + if quantizer_config.per_channel: if is_weight: if node.metatype.__name__.startswith("PTConvTranspose"): channel_axes = (1,) # Output channels for transpose conv @@ -477,13 +478,13 @@ def _get_stat_collector( all_axes = set(range(len(shape))) reduction_axes = tuple(all_axes - set(channel_axes)) - range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) + range_estimator_params = self._get_range_estimator_parameters(target_point, quantizer_config) num_samples = self._subset_size if not is_weight else 1 batchwise_statistics = batchwise_statistics and not is_weight collector_params = RangeInitCollectorParams( - is_weights=is_weight, scheme=qconfig.mode, per_channel=qconfig.per_channel + is_weights=is_weight, scheme=quantizer_config.mode, per_channel=quantizer_config.per_channel ) reduction_axes, aggregation_axes = None, None if shape is not None: From 0ee4f8b1a77422c9ac79b9f2e5f52b9ce71f8a4c Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Fri, 7 Mar 2025 21:52:43 -0500 Subject: [PATCH 05/11] refactor parameter names in --- .../algorithms/min_max/algorithm.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 128bd6c00b5..906fa6a1f0f 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -439,29 +439,29 @@ def _get_range_estimator_parameters( def _get_stat_collector( self, - nncf_graph: NNCFGraph, + graph: NNCFGraph, target_point: TargetPoint, - quantizer_config: QuantizerConfig, + qconfig: QuantizerConfig, batchwise_statistics: bool, ) -> TensorCollector: """ Creates and returns a statistic collector based on the quantizer's configuration. - :param nncf_graph: NNCFGraph instance. + :param graph: NNCFGraph instance. :param target_point: Target point indicates where statistics should be collected. - :param quantizer_config: Configuration of a quantizer layer, + :param qconfig: Configuration of a quantizer layer, defining the configuration of created statistic collector. :param batchwise_statistics: Determines whether quantizer statistics should be calculated for each item of the batch or for the entire batch. - :return: TensorCollector for the statistics calculation. + :return: Statistic Collector. """ is_weight = target_point.is_weight_target_point() - node = nncf_graph.get_node_by_name(target_point.target_node_name) - shape = self._backend_entity.get_target_point_shape(nncf_graph, node, target_point) + node = graph.get_node_by_name(target_point.target_node_name) + shape = self._backend_entity.get_target_point_shape(graph, node, target_point) # Get channel axes considering ConvTranspose layers channel_axes = () - if quantizer_config.per_channel: + if qconfig.per_channel: if is_weight: if node.metatype.__name__.startswith("PTConvTranspose"): channel_axes = (1,) # Output channels for transpose conv @@ -478,13 +478,13 @@ def _get_stat_collector( all_axes = set(range(len(shape))) reduction_axes = tuple(all_axes - set(channel_axes)) - range_estimator_params = self._get_range_estimator_parameters(target_point, quantizer_config) + range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) num_samples = self._subset_size if not is_weight else 1 batchwise_statistics = batchwise_statistics and not is_weight collector_params = RangeInitCollectorParams( - is_weights=is_weight, scheme=quantizer_config.mode, per_channel=quantizer_config.per_channel + is_weights=is_weight, scheme=qconfig.mode, per_channel=qconfig.per_channel ) reduction_axes, aggregation_axes = None, None if shape is not None: From 74f677a17621bbac965c1b51cbcf9776e669540f Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Mon, 10 Mar 2025 19:01:57 -0400 Subject: [PATCH 06/11] use torch weight_channel_axes in torchfx --- .../algorithms/min_max/torch_fx_backend.py | 37 +++++++------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 84cc9a25c49..1229b1452e4 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -150,8 +150,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point @staticmethod def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: - # TODO(dlyakhov): support transpose conv and other cases - return (0,) + return get_weight_channel_axes(node.metatype, ndims, target_point.input_port_id) @staticmethod def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: @@ -179,32 +178,22 @@ def _get_input_scale_shape( ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: is_weights = target_point.is_weight_target_point() input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) - + if is_weights: node = nncf_graph.get_node_by_name(target_point.target_node_name) - # Get channel axes considering weight port ID - channel_axes = get_weight_channel_axes( - node.metatype, - len(input_shape), - target_point.input_port_id - ) - if channel_axes: - channel_idx = channel_axes[0] - scale_shape = tuple(get_scale_shape( - input_shape, - is_weights=True, - per_channel=per_channel, - channel_idx=channel_idx - )) - else: - scale_shape = (1,) - channel_idx = 0 + channel_axes = get_weight_channel_axes(node.metatype, len(input_shape), target_point.input_port_id) + else: + channel_axes = [1] + + channel_idx = channel_axes[0] if channel_axes else 0 + + if is_weights and not channel_axes: + scale_shape = (1,) else: - channel_idx = 1 scale_shape = tuple(get_scale_shape( - input_shape, - is_weights=False, - per_channel=per_channel, + input_shape, + is_weights=is_weights, + per_channel=per_channel, channel_idx=channel_idx )) From 3565ba3fbd970b49410d1ca4942ec16ed0439836 Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Mon, 10 Mar 2025 19:02:54 -0400 Subject: [PATCH 07/11] streamline channel axes handling --- nncf/quantization/algorithms/min_max/algorithm.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 906fa6a1f0f..4357e45fbfd 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -459,28 +459,15 @@ def _get_stat_collector( node = graph.get_node_by_name(target_point.target_node_name) shape = self._backend_entity.get_target_point_shape(graph, node, target_point) - # Get channel axes considering ConvTranspose layers channel_axes = () if qconfig.per_channel: if is_weight: - if node.metatype.__name__.startswith("PTConvTranspose"): - channel_axes = (1,) # Output channels for transpose conv - else: - channel_axes = self._backend_entity.get_weight_quantization_axes( - node, target_point, len(shape) - ) + channel_axes = self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) else: channel_axes = (1,) - # Align statistics collection with scale shape - reduction_axes, aggregation_axes = None, None - if shape is not None and channel_axes: - all_axes = set(range(len(shape))) - reduction_axes = tuple(all_axes - set(channel_axes)) - range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) num_samples = self._subset_size if not is_weight else 1 - batchwise_statistics = batchwise_statistics and not is_weight collector_params = RangeInitCollectorParams( From 7efa67546396d58f0a8198cace758fd99a312128 Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Thu, 20 Mar 2025 15:56:41 -0400 Subject: [PATCH 08/11] refactor channel axes handling --- nncf/quantization/algorithms/min_max/algorithm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 4357e45fbfd..21b051b1697 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -461,11 +461,11 @@ def _get_stat_collector( channel_axes = () if qconfig.per_channel: - if is_weight: - channel_axes = self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) - else: - channel_axes = (1,) + channel_axes = ( + self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) if is_weight else (1,) + ) + # Weight statistics is constant, so only one collection is enough. range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) num_samples = self._subset_size if not is_weight else 1 batchwise_statistics = batchwise_statistics and not is_weight From e8ab53a769e26ae0630fa1ddf27f065c372399bd Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Thu, 20 Mar 2025 17:00:42 -0400 Subject: [PATCH 09/11] refactor channel axes --- nncf/quantization/algorithms/min_max/algorithm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 21b051b1697..0cf51ca209e 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -458,6 +458,7 @@ def _get_stat_collector( is_weight = target_point.is_weight_target_point() node = graph.get_node_by_name(target_point.target_node_name) shape = self._backend_entity.get_target_point_shape(graph, node, target_point) + range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) channel_axes = () if qconfig.per_channel: @@ -466,8 +467,8 @@ def _get_stat_collector( ) # Weight statistics is constant, so only one collection is enough. - range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) num_samples = self._subset_size if not is_weight else 1 + batchwise_statistics = batchwise_statistics and not is_weight collector_params = RangeInitCollectorParams( From 5bb69f78e49c5f622433e6778aa4b9661061ce57 Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Thu, 20 Mar 2025 17:01:45 -0400 Subject: [PATCH 10/11] refactor channel axes condition --- nncf/quantization/algorithms/min_max/torch_fx_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 1229b1452e4..6c3f4c23252 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -187,7 +187,7 @@ def _get_input_scale_shape( channel_idx = channel_axes[0] if channel_axes else 0 - if is_weights and not channel_axes: + if not len(channel_axes): scale_shape = (1,) else: scale_shape = tuple(get_scale_shape( From 3104e58cf4d1269a7ae76e33b4e4da50c941dd77 Mon Sep 17 00:00:00 2001 From: Siddhant Chauhan Date: Thu, 20 Mar 2025 17:03:15 -0400 Subject: [PATCH 11/11] remove whitespace --- nncf/quantization/algorithms/min_max/algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 0cf51ca209e..220422dad92 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -459,7 +459,7 @@ def _get_stat_collector( node = graph.get_node_by_name(target_point.target_node_name) shape = self._backend_entity.get_target_point_shape(graph, node, target_point) range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) - + channel_axes = () if qconfig.per_channel: channel_axes = (