[NNCF] (#3249) Remove backend-specific methods from common layer attributes #3287

Merged
Changes from all commits
22 commits
868335b
[NNCF] Add get_weight_shape_legacy function (#3249)
shumaari Feb 16, 2025
f04a525
[NNCF] Add get_target_dim_for_compression_legacy function (#3249)
shumaari Feb 17, 2025
a94e2a1
[NNCF] Add get_bias_shape_legacy function (#3249)
shumaari Feb 17, 2025
e0db330
[NNCF] Experimental torch backend: Replace (#3249)
shumaari Feb 16, 2025
433ff6b
[NNCF] Experimental tensorflow backend: Replace (#3249)
shumaari Feb 16, 2025
e437274
[NNCF] Experimental common backend: Replace (#3249)
shumaari Feb 16, 2025
39efbbd
[NNCF] Torch backend: Replace (#3249)
shumaari Feb 17, 2025
eb7905f
[NNCF] Remove backend-specific methods from common layer attributes
shumaari Feb 17, 2025
58d5112
[NNCF] Missing import statements in utils.py
shumaari Feb 18, 2025
bdd8b61
[NNCF] Remove abstract methods from common layer attributes
shumaari Feb 21, 2025
6c3e101
[NNCF] Add get_num_filters_legacy function
shumaari Feb 21, 2025
53f1ede
[NNCF] Replace get_num_filters: experimental torch backend
shumaari Feb 21, 2025
948ae9f
[NNCF] Remove get_num_filters from common layer attributes
shumaari Feb 21, 2025
0425fbb
[NNCF] Update get_weight_shape_legacy function
shumaari Feb 21, 2025
b6f3a39
[NNCF] Update get_target_dim_for_compression_legacy function
shumaari Feb 21, 2025
02ecd45
[NNCF] Update get_bias_shape_legacy function
shumaari Feb 21, 2025
907f35f
Formatting changes due to black and isort
shumaari Feb 21, 2025
b70c03a
Merge branch 'openvinotoolkit:develop' into 3249__remove_backend_spec…
shumaari Feb 22, 2025
d4493d5
[NNCF] Replace calls: Torch backend
shumaari Feb 23, 2025
c24ac61
Implement suggested changes in doc strings
shumaari Feb 24, 2025
7de7e8b
Implement suggested refactoring of code
shumaari Feb 24, 2025
0377d87
Implement suggested changes
shumaari Feb 24, 2025
45 changes: 0 additions & 45 deletions nncf/common/graph/layer_attributes.py
@@ -10,7 +10,6 @@
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Tuple, Union
@@ -66,18 +65,6 @@ def __init__(self, weight_requires_grad: bool, dtype: Dtype = Dtype.FLOAT, with_
self.dtype = dtype
self.with_bias = with_bias

@abstractmethod
def get_weight_shape(self) -> List[int]:
pass

def get_num_filters(self) -> int:
weight_shape = self.get_weight_shape()
return weight_shape[self.get_target_dim_for_compression()]

@abstractmethod
def get_target_dim_for_compression(self) -> int:
pass


class GenericWeightedLayerAttributes(WeightedLayerAttributes):
"""
@@ -103,12 +90,6 @@ def __init__(
self.weight_shape = weight_shape
self.filter_dimension_idx = filter_dimension_idx

def get_weight_shape(self) -> List[int]:
return self.weight_shape

def get_target_dim_for_compression(self) -> int:
return 0


class LinearLayerAttributes(WeightedLayerAttributes):
def __init__(
@@ -129,15 +110,6 @@ def __init__(
self.in_features = in_features
self.out_features = out_features

def get_weight_shape(self) -> List[int]:
return [self.out_features, self.in_features]

def get_bias_shape(self) -> int:
return self.out_features if self.with_bias is True else 0

def get_target_dim_for_compression(self) -> int:
return 0


class ConvolutionLayerAttributes(WeightedLayerAttributes):
def __init__(
@@ -179,17 +151,6 @@ def __init__(
self.padding_values = padding_values
self.output_padding_values = output_padding_values

def get_weight_shape(self) -> List[int]:
if not self.transpose:
return [self.out_channels, self.in_channels // self.groups, *self.kernel_size]
return [self.in_channels, self.out_channels // self.groups, *self.kernel_size]

def get_target_dim_for_compression(self) -> int:
# Always quantize per each "out" channel
if self.transpose:
return 1
return 0


class GroupNormLayerAttributes(WeightedLayerAttributes):
def __init__(self, weight_requires_grad: bool, num_channels: int, num_groups: int):
@@ -204,12 +165,6 @@ def __init__(self, weight_requires_grad: bool, num_channels: int, num_groups: in
self.num_channels = num_channels
self.num_groups = num_groups

def get_weight_shape(self) -> List[int]:
return [self.num_channels]

def get_target_dim_for_compression(self) -> int:
return 0


@dataclass
class ReshapeLayerAttributes(BaseLayerAttributes):
80 changes: 80 additions & 0 deletions nncf/common/graph/utils.py
@@ -14,6 +14,11 @@

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.layer_attributes import GenericWeightedLayerAttributes
from nncf.common.graph.layer_attributes import GroupNormLayerAttributes
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.logging import nncf_logger
from nncf.common.pruning.utils import traverse_function
@@ -132,3 +137,78 @@ def get_reduction_axes(
for channel_axis in sorted(channel_axes, reverse=True):
del reduction_axes[channel_axis]
return tuple(reduction_axes)


def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[int]:
"""
Returns hard-coded weights shape layout for the given layer attributes.
Applicable only for eager PyTorch and Tensorflow models.

:param layer_attributes: Layer attributes of a NNCFNode.
:return: Weights shape layout.
"""
if isinstance(layer_attributes, LinearLayerAttributes):
return [layer_attributes.out_features, layer_attributes.in_features]

if isinstance(layer_attributes, ConvolutionLayerAttributes):
if not layer_attributes.transpose:
return [
layer_attributes.out_channels,
layer_attributes.in_channels // layer_attributes.groups,
*layer_attributes.kernel_size,
]
return [
layer_attributes.in_channels,
layer_attributes.out_channels // layer_attributes.groups,
*layer_attributes.kernel_size,
]

if isinstance(layer_attributes, GroupNormLayerAttributes):
return [layer_attributes.num_channels]

assert isinstance(layer_attributes, GenericWeightedLayerAttributes)
return layer_attributes.weight_shape


def get_target_dim_for_compression_legacy(layer_attributes: WeightedLayerAttributes) -> int:
"""
Returns hard-coded target dim for compression for the given layer attributes.
Applicable only for eager PyTorch and Tensorflow models.

:param layer_attributes: Layer attributes of a NNCFNode.
:return: Target dim for compression.
"""
if isinstance(layer_attributes, ConvolutionLayerAttributes):
# Always quantize per each "out" channel
return 1 if layer_attributes.transpose else 0

else:
assert isinstance(
layer_attributes, (GenericWeightedLayerAttributes, LinearLayerAttributes, GroupNormLayerAttributes)
)
return 0


def get_bias_shape_legacy(layer_attributes: WeightedLayerAttributes) -> int:
"""
Returns hard-coded bias shape for the given linear layer attributes.
Applicable only for eager PyTorch and Tensorflow models.

:param layer_attributes: Linear layer attributes of a NNCFNode.
:return: Correspondent bias shape.
"""
assert isinstance(layer_attributes, LinearLayerAttributes)
return layer_attributes.out_features if layer_attributes.with_bias is True else 0


def get_num_filters_legacy(layer_attributes: WeightedLayerAttributes) -> int:
"""
Returns hard-coded number of filters for the given layer attributes.
Applicable only for eager PyTorch and Tensorflow models.

:param layer_attributes: Layer attributes of a NNCFNode.
:return: Correspondent number of filters.
"""
assert isinstance(layer_attributes, WeightedLayerAttributes)
weight_shape = get_weight_shape_legacy(layer_attributes)
return weight_shape[get_target_dim_for_compression_legacy(layer_attributes)]
3 changes: 2 additions & 1 deletion nncf/experimental/common/pruning/nodes_grouping.py
@@ -18,6 +18,7 @@
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.common.graph.utils import get_target_dim_for_compression_legacy
from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm
from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry
from nncf.experimental.common.graph.netron import save_for_netron
@@ -76,7 +77,7 @@ def get_pruning_groups(
roots = {}
for node in all_nodes_to_prune:
assert isinstance(node.layer_attributes, (LinearLayerAttributes, ConvolutionLayerAttributes))
pruning_dim = node.layer_attributes.get_target_dim_for_compression()
pruning_dim = get_target_dim_for_compression_legacy(node.layer_attributes)
output_tensors_shapes = [x.tensor_shape for x in graph.get_output_edges(node)]
assert not len(set(output_tensors_shapes)) > 1, node.node_name
output_tensors_shape = output_tensors_shapes[0]
3 changes: 2 additions & 1 deletion nncf/experimental/tensorflow/quantization/algorithm.py
@@ -18,6 +18,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.utils import get_first_nodes_of_type
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.common.logging import nncf_logger
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
from nncf.common.quantization.quantizer_setup import QuantizationPointId
@@ -133,7 +134,7 @@ def _get_tensor_specs(
assert len(metatype.weight_definitions) == 1

channel_axes = metatype.weight_definitions[0].channel_axes
weight_shape = node.layer_attributes.get_weight_shape()
weight_shape = get_weight_shape_legacy(node.layer_attributes)
tensor_specs.append((weight_shape, channel_axes))
else:
data_format = node.layer_attributes.get_data_format()
@@ -27,6 +27,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.utils import get_num_filters_legacy
from nncf.common.logging import nncf_logger
from nncf.common.pruning.clusterization import Cluster
from nncf.common.pruning.clusterization import Clusterization
@@ -1204,15 +1205,15 @@ def _create_dynamic_dw_conv_input_op(conv_layer_attrs: BaseLayerAttributes, node
def _create_dynamic_bn_input_op(generic_layer_attrs: BaseLayerAttributes, node_name: str) -> UpdateBatchNormParams:
assert isinstance(generic_layer_attrs, GenericWeightedLayerAttributes)
dynamic_bn_input_op = ElasticInputWidthBatchNormOp(
max_width=generic_layer_attrs.get_num_filters(), node_name=node_name
max_width=get_num_filters_legacy(generic_layer_attrs), node_name=node_name
)
return UpdateBatchNormParams(dynamic_bn_input_op)

@staticmethod
def _create_dynamic_ln_input_op(generic_layer_attrs: BaseLayerAttributes, node_name: str) -> UpdateLayerNormParams:
assert isinstance(generic_layer_attrs, GenericWeightedLayerAttributes)
dynamic_ln_input_op = ElasticInputWidthLayerNormOp(
max_width=generic_layer_attrs.get_num_filters(), node_name=node_name
max_width=get_num_filters_legacy(generic_layer_attrs), node_name=node_name
)
return UpdateLayerNormParams(dynamic_ln_input_op)

6 changes: 4 additions & 2 deletions nncf/experimental/torch/sparsity/movement/layers.py
@@ -18,6 +18,8 @@

import nncf
from nncf.common.graph import NNCFNode
from nncf.common.graph.utils import get_bias_shape_legacy
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.experimental.torch.sparsity.movement.functions import binary_mask_by_threshold
from nncf.torch.layer_utils import COMPRESSION_MODULES
from nncf.torch.layer_utils import CompressionParameter
@@ -167,7 +169,7 @@ def __init__(
self._importance_threshold = -math.inf
self._importance_regularization_factor = 0.0

weight_shape: List[int] = target_module_node.layer_attributes.get_weight_shape()
weight_shape: List[int] = get_weight_shape_legacy(target_module_node.layer_attributes)
assert len(weight_shape) == 2, "Unsupported module with weight shape not in 2D."
self.weight_ctx = BinaryMask(weight_shape)
self.sparse_factors = self._get_sparse_factors(weight_shape, sparse_cfg)
@@ -185,7 +187,7 @@ def __init__(
self.weight_ctx.binary_mask = self._calc_training_binary_mask()

if self.prune_bias:
bias_shape = target_module_node.layer_attributes.get_bias_shape()
bias_shape = get_bias_shape_legacy(target_module_node.layer_attributes)
self.bias_ctx = BinaryMask(bias_shape)
bias_importance_shape = weight_importance_shape[0]
self.bias_importance = CompressionParameter(
@@ -18,6 +18,8 @@

from nncf.common.graph.graph import NNCFNodeName
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.common.graph.utils import get_bias_shape_legacy
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.common.logging import nncf_logger
from nncf.experimental.common.pruning.nodes_grouping import get_pruning_groups
from nncf.experimental.common.pruning.nodes_grouping import select_largest_groups
@@ -208,9 +210,9 @@ def gather_statistics_from_operand(self) -> StructuredMaskContextStatistics:
"""
node = self.sparsifier_operand.target_module_node
assert isinstance(node.layer_attributes, tuple(EXPECTED_NODE_LAYER_ATTRS))
weight_shape: Tuple[int, int] = tuple(node.layer_attributes.get_weight_shape())
weight_shape: Tuple[int, int] = tuple(get_weight_shape_legacy(node.layer_attributes))
bias_shape: Tuple[int] = (
(node.layer_attributes.get_bias_shape(),) if self.sparsifier_operand.prune_bias else (0,)
(get_bias_shape_legacy(node.layer_attributes),) if self.sparsifier_operand.prune_bias else (0,)
)

pruned_weight_shape = list(weight_shape)
12 changes: 7 additions & 5 deletions nncf/torch/quantization/algo.py
@@ -35,6 +35,8 @@
from nncf.common.graph.patterns.manager import TargetDevice
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.utils import get_first_nodes_of_type
from nncf.common.graph.utils import get_target_dim_for_compression_legacy
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.common.hardware.config import HWConfig
from nncf.common.hardware.config import HWConfigType
from nncf.common.hardware.config import get_hw_config_type
@@ -603,8 +605,8 @@ def _get_minmax_values_for_quantizer_locations(
if qp.is_weight_quantization_point():
layer_attrs = target_node.layer_attributes
assert isinstance(layer_attrs, WeightedLayerAttributes)
input_shape = layer_attrs.get_weight_shape()
channel_idx = layer_attrs.get_target_dim_for_compression()
input_shape = get_weight_shape_legacy(layer_attrs)
channel_idx = get_target_dim_for_compression_legacy(layer_attrs)
else:
input_shape = target_model_graph.get_input_shape_for_insertion_point(qp.insertion_point)
channel_idx = 1 # channel dim for activations
@@ -773,10 +775,10 @@ def _get_quantizer_setup(self, target_model: NNCFNetwork) -> PTQuantizerSetup:
layer_attributes = target_node.layer_attributes
assert isinstance(layer_attributes, WeightedLayerAttributes)
scale_shape = get_scale_shape(
layer_attributes.get_weight_shape(),
get_weight_shape_legacy(layer_attributes),
is_weights=True,
per_channel=qconfig.per_channel,
channel_idx=layer_attributes.get_target_dim_for_compression(),
channel_idx=get_target_dim_for_compression_legacy(layer_attributes),
)
else:
input_shape = target_model_graph.get_input_shape_for_insertion_point(insertion_point)
@@ -1181,7 +1183,7 @@ def is_weights(ip: PTTargetPoint) -> bool:
)
module_node = target_model_graph.get_node_by_name(primary_ip.target_node_name)
layer_attributes = module_node.layer_attributes
input_shape = layer_attributes.get_weight_shape()
input_shape = get_weight_shape_legacy(layer_attributes)
self._quantizers_input_shapes[primary_qid] = tuple(input_shape)
else:
primary_qid = NonWeightQuantizerId(primary_ip.target_node_name, primary_ip.input_port_id)
6 changes: 4 additions & 2 deletions nncf/torch/quantization/init_range.py
@@ -18,6 +18,8 @@

import nncf
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
from nncf.common.graph.utils import get_target_dim_for_compression_legacy
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.common.quantization.initialization.range import RangeInitCollectorParams
from nncf.common.quantization.initialization.range import RangeInitConfig
from nncf.common.quantization.initialization.range import RangeInitParams
@@ -226,8 +228,8 @@ def get_all_scale_shapes_with_params(
module_node = target_nncf_graph.get_node_by_name(qp.insertion_point.target_node_name)
layer_attributes = module_node.layer_attributes
assert isinstance(layer_attributes, WeightedLayerAttributes)
input_shape = layer_attributes.get_weight_shape()
channel_idx = layer_attributes.get_target_dim_for_compression()
input_shape = get_weight_shape_legacy(layer_attributes)
channel_idx = get_target_dim_for_compression_legacy(layer_attributes)
else:
input_shape = target_nncf_graph.get_input_shape_for_insertion_point(qp.insertion_point)
channel_idx = 1 # channel dim for activations
3 changes: 2 additions & 1 deletion nncf/torch/sparsity/const/algo.py
@@ -11,6 +11,7 @@
from typing import Tuple

from nncf.common.graph import NNCFNode
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.common.sparsity.statistics import ConstSparsityStatistics
from nncf.common.statistics import NNCFStatistics
from nncf.common.utils.api_marker import api
@@ -26,7 +27,7 @@
@PT_COMPRESSION_ALGORITHMS.register("const_sparsity")
class ConstSparsityBuilder(BaseSparsityAlgoBuilder):
def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
return BinaryMask(target_module_node.layer_attributes.get_weight_shape())
return BinaryMask(get_weight_shape_legacy(target_module_node.layer_attributes))

def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
return ConstSparsityController(model, self._sparsified_module_info)
3 changes: 2 additions & 1 deletion nncf/torch/sparsity/magnitude/algo.py
@@ -17,6 +17,7 @@
from nncf.api.compression import CompressionStage
from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
from nncf.common.graph import NNCFNode
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.common.initialization.batchnorm_adaptation import BatchnormAdaptationAlgorithm
from nncf.common.schedulers import StubCompressionScheduler
from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
Expand Down Expand Up @@ -44,7 +45,7 @@
@PT_COMPRESSION_ALGORITHMS.register("magnitude_sparsity")
class MagnitudeSparsityBuilder(BaseSparsityAlgoBuilder):
def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
return BinaryMask(target_module_node.layer_attributes.get_weight_shape())
return BinaryMask(get_weight_shape_legacy(target_module_node.layer_attributes))

def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
return MagnitudeSparsityController(model, self._sparsified_module_info, self.config)
3 changes: 2 additions & 1 deletion nncf/torch/sparsity/rb/algo.py
@@ -18,6 +18,7 @@
from nncf.api.compression import CompressionStage
from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
from nncf.common.graph import NNCFNode
from nncf.common.graph.utils import get_weight_shape_legacy
from nncf.common.schedulers import StubCompressionScheduler
from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
from nncf.common.sparsity.statistics import RBSparsityStatistics
@@ -44,7 +45,7 @@
class RBSparsityBuilder(BaseSparsityAlgoBuilder):
def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
return RBSparsifyingWeight(
target_module_node.layer_attributes.get_weight_shape(),
get_weight_shape_legacy(target_module_node.layer_attributes),
frozen=False,
compression_lr_multiplier=compression_lr_multiplier,
)