diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
index 564d6b12a1d..6c3f4c23252 100644
--- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -40,6 +40,7 @@
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.hardware.config import PTHWConfig
+from nncf.torch.model_graph_manager import get_weight_channel_axes
 from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
 from nncf.torch.quantization.layers import QUANTIZATION_MODULES
@@ -149,8 +150,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point
 
     @staticmethod
     def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]:
-        # TODO(dlyakhov): support transpose conv and other cases
-        return (0,)
+        return get_weight_channel_axes(node.metatype, ndims, target_point.input_port_id)
 
     @staticmethod
     def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]:
@@ -177,13 +177,20 @@ def _get_input_scale_shape(
         nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
     ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
         is_weights = target_point.is_weight_target_point()
+        input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
+
         if is_weights:
-            # TODO(dlyakhov): support transpose conv/ make channel_idx common
-            channel_idx = 0
+            node = nncf_graph.get_node_by_name(target_point.target_node_name)
+            channel_axes = get_weight_channel_axes(node.metatype, len(input_shape), target_point.input_port_id)
         else:
-            channel_idx = 1  # channel dim for activations
+            channel_axes = (1,)  # channel dim for activations
 
-        input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
+        if not channel_axes:
+            # No channel axis was reported for this weight: use a single-element scale and channel index 0.
+            return input_shape, (1,), 0
+
+        # Only the first reported channel axis is used to build the scale shape.
+        channel_idx = channel_axes[0]
         scale_shape = tuple(
             get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
         )