Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FX] Support weight quantization for operations where weight_port_id != 1 #3334

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
27 changes: 18 additions & 9 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
from nncf.torch.model_graph_manager import get_weight_channel_axes
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
@@ -149,8 +150,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point

@staticmethod
def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]:
# TODO(dlyakhov): support transpose conv and other cases
return (0,)
return get_weight_channel_axes(node.metatype, ndims, target_point.input_port_id)

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]:
@@ -177,16 +177,25 @@ def _get_input_scale_shape(
nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
is_weights = target_point.is_weight_target_point()
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)

if is_weights:
# TODO(dlyakhov): support transpose conv/ make channel_idx common
channel_idx = 0
node = nncf_graph.get_node_by_name(target_point.target_node_name)
channel_axes = get_weight_channel_axes(node.metatype, len(input_shape), target_point.input_port_id)
else:
channel_idx = 1 # channel dim for activations
channel_axes = [1]

input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
scale_shape = tuple(
get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
)
channel_idx = channel_axes[0] if channel_axes else 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
channel_idx = channel_axes[0] if channel_axes else 0

Since `channel_axes` is already checked and handled in the if-else block below, `channel_axes[0]` can be passed directly as the `channel_idx` argument of `get_scale_shape`.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to return `channel_idx` from this function, so I think this assignment is still needed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay


if not len(channel_axes):
scale_shape = (1,)
else:
scale_shape = tuple(get_scale_shape(
input_shape,
is_weights=is_weights,
per_channel=per_channel,
channel_idx=channel_idx
))

return input_shape, scale_shape, channel_idx

Loading