[FX] Support weight quantization for operations where weight_port_id != 1 #3334

Open · wants to merge 11 commits into develop
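For context, the fixed assumption this PR removes is that a layer's weight always arrives on input port 1. A minimal, standalone PyTorch illustration (the ops below are generic examples, not necessarily the exact metatypes the PR targets): the same matmul can expose its weight on different input ports depending on how it is captured in the FX graph.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)

# F.linear(input, weight, bias): the weight arrives on input port 1.
y_linear = F.linear(x, w, b)

# torch.addmm(bias, mat1, mat2): the weight (mat2) arrives on input port 2.
y_addmm = torch.addmm(b, x, w.t())

assert torch.allclose(y_linear, y_addmm, atol=1e-5)
```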
24 changes: 18 additions & 6 deletions nncf/quantization/algorithms/min_max/algorithm.py
@@ -458,15 +458,27 @@ def _get_stat_collector(
         is_weight = target_point.is_weight_target_point()
         node = graph.get_node_by_name(target_point.target_node_name)
         shape = self._backend_entity.get_target_point_shape(graph, node, target_point)
-        range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig)
 
+        # Get channel axes considering ConvTranspose layers
         channel_axes = ()
         if qconfig.per_channel:
-            channel_axes = (
-                self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) if is_weight else (1,)
-            )
+            if is_weight:
+                if node.metatype.__name__.startswith("PTConvTranspose"):
+                    channel_axes = (1,)  # Output channels for transpose conv
+                else:
+                    channel_axes = self._backend_entity.get_weight_quantization_axes(
+                        node, target_point, len(shape)
+                    )
+            else:
+                channel_axes = (1,)
+
+        # Align statistics collection with scale shape
+        reduction_axes, aggregation_axes = None, None
+        if shape is not None and channel_axes:
+            all_axes = set(range(len(shape)))
+            reduction_axes = tuple(all_axes - set(channel_axes))
 
         # Weight statistics is constant, so only one collection is enough.
+        range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig)
         num_samples = self._subset_size if not is_weight else 1
 
         batchwise_statistics = batchwise_statistics and not is_weight
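To make the PTConvTranspose branch above concrete, here is a standalone PyTorch check (not part of the diff) of where the output-channel axis sits for regular versus transposed convolutions, and the reduction axes that follow from it:

```python
import torch.nn as nn

conv = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
deconv = nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3)

# Regular convolution stores weights as (out_channels, in_channels, kH, kW) ...
print(tuple(conv.weight.shape))    # (16, 8, 3, 3)
# ... transposed convolution stores them as (in_channels, out_channels, kH, kW),
# so the output channels sit on axis 1, matching channel_axes = (1,) above.
print(tuple(deconv.weight.shape))  # (8, 16, 3, 3)

# Statistics are then reduced over every non-channel axis, as in the diff:
channel_axes = (1,)
all_axes = set(range(deconv.weight.ndim))
print(tuple(sorted(all_axes - set(channel_axes))))  # (0, 2, 3)
```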
36 changes: 28 additions & 8 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -40,6 +40,7 @@
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.hardware.config import PTHWConfig
 from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
+from nncf.torch.model_graph_manager import get_weight_channel_axes
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
 from nncf.torch.quantization.layers import QUANTIZATION_MODULES
@@ -177,16 +178,35 @@ def _get_input_scale_shape(
         nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
     ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
         is_weights = target_point.is_weight_target_point()
+        input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
+
         if is_weights:
-            # TODO(dlyakhov): support transpose conv/ make channel_idx common
-            channel_idx = 0
+            node = nncf_graph.get_node_by_name(target_point.target_node_name)
+            # Get channel axes considering weight port ID
+            channel_axes = get_weight_channel_axes(
+                node.metatype,
+                len(input_shape),
+                target_point.input_port_id
+            )
+            if channel_axes:
+                channel_idx = channel_axes[0]
+                scale_shape = tuple(get_scale_shape(
+                    input_shape,
+                    is_weights=True,
+                    per_channel=per_channel,
+                    channel_idx=channel_idx
+                ))
+            else:
+                scale_shape = (1,)
+                channel_idx = 0
         else:
-            channel_idx = 1  # channel dim for activations
-
-        input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
-        scale_shape = tuple(
-            get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
-        )
+            channel_idx = 1
+            scale_shape = tuple(get_scale_shape(
+                input_shape,
+                is_weights=False,
+                per_channel=per_channel,
+                channel_idx=channel_idx
+            ))
 
         return input_shape, scale_shape, channel_idx
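As a sanity check on the scale shapes this backend change produces, here is a minimal sketch (the per_channel_scale_shape helper below is hypothetical and only mirrors the per-channel case; it is not NNCF's get_scale_shape): a per-channel scale keeps the size of the channel axis and broadcasts over every other axis.

```python
from typing import Tuple

def per_channel_scale_shape(input_shape: Tuple[int, ...], channel_idx: int) -> Tuple[int, ...]:
    # Keep the channel dimension, broadcast over the rest.
    scale_shape = [1] * len(input_shape)
    scale_shape[channel_idx] = input_shape[channel_idx]
    return tuple(scale_shape)

# ConvTranspose2d weight of shape (in_channels, out_channels, kH, kW), channel_idx = 1:
print(per_channel_scale_shape((8, 16, 3, 3), channel_idx=1))    # (1, 16, 1, 1)
# Activation of shape (N, C, H, W), channel_idx = 1:
print(per_channel_scale_shape((1, 32, 28, 28), channel_idx=1))  # (1, 32, 1, 1)
```

With channel_idx now driven by get_weight_channel_axes and the target weight port, a transposed-convolution weight gets its scale broadcast along axis 1 rather than axis 0.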