Commit 7996e73

[NNCF] (#3249) Remove backend-specific methods from common layer attributes (#3287)
### Changes

1. Moved and renamed backend-specific class methods from the common layer attributes in `nncf/common/graph/layer_attributes.py` to self-contained functions in `nncf/common/graph/utils.py`:
   - `layer_attributes.get_bias_shape` -> `get_bias_shape_legacy(layer_attributes: WeightedLayerAttributes)`
   - `layer_attributes.get_num_filters` -> `get_num_filters_legacy(layer_attributes: WeightedLayerAttributes)`
   - `layer_attributes.get_target_dim_for_compression` -> `get_target_dim_for_compression_legacy(layer_attributes: WeightedLayerAttributes)`
   - `layer_attributes.get_weight_shape` -> `get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes)`
2. Calls to the above class methods were replaced with calls to their corresponding legacy functions in the following locations:
   - `/nncf/experimental/common/` folder
     - `pruning/nodes_grouping.py`
   - `/nncf/experimental/tensorflow/` folder
     - `quantization/algorithm.py`
   - `/nncf/experimental/torch/` folder
     - `nas/bootstrapNAS/elasticity/elastic_width.py`
     - `sparsity/movement/layers.py`
     - `sparsity/movement/structured_mask_handler.py`
   - `/nncf/torch/` folder
     - `quantization/algo.py`
     - `quantization/init_range.py`
     - `sparsity/const/algo.py`
     - `sparsity/magnitude/algo.py`
     - `sparsity/rb/algo.py`

### Reason for changes

Torch and TensorFlow backend-specific methods need to be removed from the common layer attributes, and all related calls need to be replaced with their corresponding legacy function calls (resolves #3249).
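For downstream code, the migration is mechanical: instead of calling the removed method on the attributes object, pass the attributes to the matching free function. A minimal before/after sketch (the `node` variable is a stand-in for any `NNCFNode` whose `layer_attributes` is a `WeightedLayerAttributes` subclass):

```python
from nncf.common.graph.utils import get_weight_shape_legacy

# Before this commit (method on the common layer attributes, now removed):
#   weight_shape = node.layer_attributes.get_weight_shape()

# After this commit (self-contained function in nncf/common/graph/utils.py):
weight_shape = get_weight_shape_legacy(node.layer_attributes)
```

The same pattern applies to the other three functions; only the receiver moves into the argument position.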
1 parent d3dd9dd commit 7996e73

File tree

12 files changed: +112 -63 lines


nncf/common/graph/layer_attributes.py

-45 lines

```diff
@@ -10,7 +10,6 @@
 # limitations under the License.

 from abc import ABC
-from abc import abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from typing import Any, List, Optional, Tuple, Union
@@ -66,18 +65,6 @@ def __init__(self, weight_requires_grad: bool, dtype: Dtype = Dtype.FLOAT, with_
         self.dtype = dtype
         self.with_bias = with_bias

-    @abstractmethod
-    def get_weight_shape(self) -> List[int]:
-        pass
-
-    def get_num_filters(self) -> int:
-        weight_shape = self.get_weight_shape()
-        return weight_shape[self.get_target_dim_for_compression()]
-
-    @abstractmethod
-    def get_target_dim_for_compression(self) -> int:
-        pass
-

 class GenericWeightedLayerAttributes(WeightedLayerAttributes):
     """
@@ -103,12 +90,6 @@ def __init__(
         self.weight_shape = weight_shape
         self.filter_dimension_idx = filter_dimension_idx

-    def get_weight_shape(self) -> List[int]:
-        return self.weight_shape
-
-    def get_target_dim_for_compression(self) -> int:
-        return 0
-

 class LinearLayerAttributes(WeightedLayerAttributes):
     def __init__(
@@ -129,15 +110,6 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features

-    def get_weight_shape(self) -> List[int]:
-        return [self.out_features, self.in_features]
-
-    def get_bias_shape(self) -> int:
-        return self.out_features if self.with_bias is True else 0
-
-    def get_target_dim_for_compression(self) -> int:
-        return 0
-

 class ConvolutionLayerAttributes(WeightedLayerAttributes):
     def __init__(
@@ -179,17 +151,6 @@ def __init__(
         self.padding_values = padding_values
         self.output_padding_values = output_padding_values

-    def get_weight_shape(self) -> List[int]:
-        if not self.transpose:
-            return [self.out_channels, self.in_channels // self.groups, *self.kernel_size]
-        return [self.in_channels, self.out_channels // self.groups, *self.kernel_size]
-
-    def get_target_dim_for_compression(self) -> int:
-        # Always quantize per each "out" channel
-        if self.transpose:
-            return 1
-        return 0
-

 class GroupNormLayerAttributes(WeightedLayerAttributes):
     def __init__(self, weight_requires_grad: bool, num_channels: int, num_groups: int):
@@ -204,12 +165,6 @@ def __init__(self, weight_requires_grad: bool, num_channels: int, num_groups: in
         self.num_channels = num_channels
         self.num_groups = num_groups

-    def get_weight_shape(self) -> List[int]:
-        return [self.num_channels]
-
-    def get_target_dim_for_compression(self) -> int:
-        return 0
-

 @dataclass
 class ReshapeLayerAttributes(BaseLayerAttributes):
```

nncf/common/graph/utils.py

+80 lines

```diff
@@ -14,6 +14,11 @@

 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
+from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
+from nncf.common.graph.layer_attributes import GenericWeightedLayerAttributes
+from nncf.common.graph.layer_attributes import GroupNormLayerAttributes
+from nncf.common.graph.layer_attributes import LinearLayerAttributes
+from nncf.common.graph.layer_attributes import WeightedLayerAttributes
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.logging import nncf_logger
 from nncf.common.pruning.utils import traverse_function
@@ -132,3 +137,78 @@ def get_reduction_axes(
     for channel_axis in sorted(channel_axes, reverse=True):
         del reduction_axes[channel_axis]
     return tuple(reduction_axes)
+
+
+def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[int]:
+    """
+    Returns hard-coded weights shape layout for the given layer attributes.
+    Applicable only for eager PyTorch and Tensorflow models.
+
+    :param layer_attributes: Layer attributes of a NNCFNode.
+    :return: Weights shape layout.
+    """
+    if isinstance(layer_attributes, LinearLayerAttributes):
+        return [layer_attributes.out_features, layer_attributes.in_features]
+
+    if isinstance(layer_attributes, ConvolutionLayerAttributes):
+        if not layer_attributes.transpose:
+            return [
+                layer_attributes.out_channels,
+                layer_attributes.in_channels // layer_attributes.groups,
+                *layer_attributes.kernel_size,
+            ]
+        return [
+            layer_attributes.in_channels,
+            layer_attributes.out_channels // layer_attributes.groups,
+            *layer_attributes.kernel_size,
+        ]
+
+    if isinstance(layer_attributes, GroupNormLayerAttributes):
+        return [layer_attributes.num_channels]
+
+    assert isinstance(layer_attributes, GenericWeightedLayerAttributes)
+    return layer_attributes.weight_shape
+
+
+def get_target_dim_for_compression_legacy(layer_attributes: WeightedLayerAttributes) -> int:
+    """
+    Returns hard-coded target dim for compression for the given layer attributes.
+    Applicable only for eager PyTorch and Tensorflow models.
+
+    :param layer_attributes: Layer attributes of a NNCFNode.
+    :return: Target dim for compression.
+    """
+    if isinstance(layer_attributes, ConvolutionLayerAttributes):
+        # Always quantize per each "out" channel
+        return 1 if layer_attributes.transpose else 0
+
+    else:
+        assert isinstance(
+            layer_attributes, (GenericWeightedLayerAttributes, LinearLayerAttributes, GroupNormLayerAttributes)
+        )
+        return 0
+
+
+def get_bias_shape_legacy(layer_attributes: WeightedLayerAttributes) -> int:
+    """
+    Returns hard-coded bias shape for the given linear layer attributes.
+    Applicable only for eager PyTorch and Tensorflow models.
+
+    :param layer_attributes: Linear layer attributes of a NNCFNode.
+    :return: Correspondent bias shape.
+    """
+    assert isinstance(layer_attributes, LinearLayerAttributes)
+    return layer_attributes.out_features if layer_attributes.with_bias is True else 0
+
+
+def get_num_filters_legacy(layer_attributes: WeightedLayerAttributes) -> int:
+    """
+    Returns hard-coded number of filters for the given layer attributes.
+    Applicable only for eager PyTorch and Tensorflow models.
+
+    :param layer_attributes: Layer attributes of a NNCFNode.
+    :return: Correspondent number of filters.
+    """
+    assert isinstance(layer_attributes, WeightedLayerAttributes)
+    weight_shape = get_weight_shape_legacy(layer_attributes)
+    return weight_shape[get_target_dim_for_compression_legacy(layer_attributes)]
```
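Taken together, the four helpers reproduce the removed methods' behavior for the common attribute classes. A quick usage sketch on a hand-built `LinearLayerAttributes` (constructor keywords, including `with_bias`, are assumed from the class definitions above; the expected outputs follow directly from the function bodies):

```python
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.common.graph.utils import get_bias_shape_legacy
from nncf.common.graph.utils import get_num_filters_legacy
from nncf.common.graph.utils import get_target_dim_for_compression_legacy
from nncf.common.graph.utils import get_weight_shape_legacy

# Hypothetical attributes for a Linear(in_features=768, out_features=3072) layer with bias.
attrs = LinearLayerAttributes(weight_requires_grad=True, in_features=768, out_features=3072, with_bias=True)

print(get_weight_shape_legacy(attrs))                # [3072, 768], i.e. [out_features, in_features]
print(get_target_dim_for_compression_legacy(attrs))  # 0 (compress along the "out" dimension)
print(get_num_filters_legacy(attrs))                 # 3072, i.e. weight_shape[0]
print(get_bias_shape_legacy(attrs))                  # 3072, since with_bias is True
```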

nncf/experimental/common/pruning/nodes_grouping.py

+2 -1

```diff
@@ -18,6 +18,7 @@
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
 from nncf.common.graph.layer_attributes import LinearLayerAttributes
+from nncf.common.graph.utils import get_target_dim_for_compression_legacy
 from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm
 from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry
 from nncf.experimental.common.graph.netron import save_for_netron
@@ -76,7 +77,7 @@ def get_pruning_groups(
     roots = {}
     for node in all_nodes_to_prune:
         assert isinstance(node.layer_attributes, (LinearLayerAttributes, ConvolutionLayerAttributes))
-        pruning_dim = node.layer_attributes.get_target_dim_for_compression()
+        pruning_dim = get_target_dim_for_compression_legacy(node.layer_attributes)
         output_tensors_shapes = [x.tensor_shape for x in graph.get_output_edges(node)]
         assert not len(set(output_tensors_shapes)) > 1, node.node_name
         output_tensors_shape = output_tensors_shapes[0]
```

nncf/experimental/tensorflow/quantization/algorithm.py

+2 -1

```diff
@@ -18,6 +18,7 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationPriority
 from nncf.common.graph.utils import get_first_nodes_of_type
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.common.logging import nncf_logger
 from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
 from nncf.common.quantization.quantizer_setup import QuantizationPointId
@@ -133,7 +134,7 @@ def _get_tensor_specs(
             assert len(metatype.weight_definitions) == 1

             channel_axes = metatype.weight_definitions[0].channel_axes
-            weight_shape = node.layer_attributes.get_weight_shape()
+            weight_shape = get_weight_shape_legacy(node.layer_attributes)
             tensor_specs.append((weight_shape, channel_axes))
         else:
             data_format = node.layer_attributes.get_data_format()
```

nncf/experimental/torch/nas/bootstrapNAS/elasticity/elastic_width.py

+3 -2

```diff
@@ -27,6 +27,7 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationCommand
 from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.common.graph.utils import get_num_filters_legacy
 from nncf.common.logging import nncf_logger
 from nncf.common.pruning.clusterization import Cluster
 from nncf.common.pruning.clusterization import Clusterization
@@ -1204,15 +1205,15 @@ def _create_dynamic_dw_conv_input_op(conv_layer_attrs: BaseLayerAttributes, node
     def _create_dynamic_bn_input_op(generic_layer_attrs: BaseLayerAttributes, node_name: str) -> UpdateBatchNormParams:
         assert isinstance(generic_layer_attrs, GenericWeightedLayerAttributes)
         dynamic_bn_input_op = ElasticInputWidthBatchNormOp(
-            max_width=generic_layer_attrs.get_num_filters(), node_name=node_name
+            max_width=get_num_filters_legacy(generic_layer_attrs), node_name=node_name
         )
         return UpdateBatchNormParams(dynamic_bn_input_op)

     @staticmethod
     def _create_dynamic_ln_input_op(generic_layer_attrs: BaseLayerAttributes, node_name: str) -> UpdateLayerNormParams:
         assert isinstance(generic_layer_attrs, GenericWeightedLayerAttributes)
         dynamic_ln_input_op = ElasticInputWidthLayerNormOp(
-            max_width=generic_layer_attrs.get_num_filters(), node_name=node_name
+            max_width=get_num_filters_legacy(generic_layer_attrs), node_name=node_name
         )
         return UpdateLayerNormParams(dynamic_ln_input_op)
```

nncf/experimental/torch/sparsity/movement/layers.py

+4 -2

```diff
@@ -18,6 +18,8 @@

 import nncf
 from nncf.common.graph import NNCFNode
+from nncf.common.graph.utils import get_bias_shape_legacy
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.experimental.torch.sparsity.movement.functions import binary_mask_by_threshold
 from nncf.torch.layer_utils import COMPRESSION_MODULES
 from nncf.torch.layer_utils import CompressionParameter
@@ -167,7 +169,7 @@ def __init__(
         self._importance_threshold = -math.inf
         self._importance_regularization_factor = 0.0

-        weight_shape: List[int] = target_module_node.layer_attributes.get_weight_shape()
+        weight_shape: List[int] = get_weight_shape_legacy(target_module_node.layer_attributes)
         assert len(weight_shape) == 2, "Unsupported module with weight shape not in 2D."
         self.weight_ctx = BinaryMask(weight_shape)
         self.sparse_factors = self._get_sparse_factors(weight_shape, sparse_cfg)
@@ -185,7 +187,7 @@ def __init__(
         self.weight_ctx.binary_mask = self._calc_training_binary_mask()

         if self.prune_bias:
-            bias_shape = target_module_node.layer_attributes.get_bias_shape()
+            bias_shape = get_bias_shape_legacy(target_module_node.layer_attributes)
             self.bias_ctx = BinaryMask(bias_shape)
             bias_importance_shape = weight_importance_shape[0]
             self.bias_importance = CompressionParameter(
```

nncf/experimental/torch/sparsity/movement/structured_mask_handler.py

+4 -2

```diff
@@ -18,6 +18,8 @@

 from nncf.common.graph.graph import NNCFNodeName
 from nncf.common.graph.layer_attributes import LinearLayerAttributes
+from nncf.common.graph.utils import get_bias_shape_legacy
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.common.logging import nncf_logger
 from nncf.experimental.common.pruning.nodes_grouping import get_pruning_groups
 from nncf.experimental.common.pruning.nodes_grouping import select_largest_groups
@@ -208,9 +210,9 @@ def gather_statistics_from_operand(self) -> StructuredMaskContextStatistics:
         """
         node = self.sparsifier_operand.target_module_node
         assert isinstance(node.layer_attributes, tuple(EXPECTED_NODE_LAYER_ATTRS))
-        weight_shape: Tuple[int, int] = tuple(node.layer_attributes.get_weight_shape())
+        weight_shape: Tuple[int, int] = tuple(get_weight_shape_legacy(node.layer_attributes))
         bias_shape: Tuple[int] = (
-            (node.layer_attributes.get_bias_shape(),) if self.sparsifier_operand.prune_bias else (0,)
+            (get_bias_shape_legacy(node.layer_attributes),) if self.sparsifier_operand.prune_bias else (0,)
         )

         pruned_weight_shape = list(weight_shape)
```

nncf/torch/quantization/algo.py

+7 -5

```diff
@@ -35,6 +35,8 @@
 from nncf.common.graph.patterns.manager import TargetDevice
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.utils import get_first_nodes_of_type
+from nncf.common.graph.utils import get_target_dim_for_compression_legacy
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.common.hardware.config import HWConfig
 from nncf.common.hardware.config import HWConfigType
 from nncf.common.hardware.config import get_hw_config_type
@@ -603,8 +605,8 @@ def _get_minmax_values_for_quantizer_locations(
             if qp.is_weight_quantization_point():
                 layer_attrs = target_node.layer_attributes
                 assert isinstance(layer_attrs, WeightedLayerAttributes)
-                input_shape = layer_attrs.get_weight_shape()
-                channel_idx = layer_attrs.get_target_dim_for_compression()
+                input_shape = get_weight_shape_legacy(layer_attrs)
+                channel_idx = get_target_dim_for_compression_legacy(layer_attrs)
             else:
                 input_shape = target_model_graph.get_input_shape_for_insertion_point(qp.insertion_point)
                 channel_idx = 1  # channel dim for activations
@@ -773,10 +775,10 @@ def _get_quantizer_setup(self, target_model: NNCFNetwork) -> PTQuantizerSetup:
                 layer_attributes = target_node.layer_attributes
                 assert isinstance(layer_attributes, WeightedLayerAttributes)
                 scale_shape = get_scale_shape(
-                    layer_attributes.get_weight_shape(),
+                    get_weight_shape_legacy(layer_attributes),
                     is_weights=True,
                     per_channel=qconfig.per_channel,
-                    channel_idx=layer_attributes.get_target_dim_for_compression(),
+                    channel_idx=get_target_dim_for_compression_legacy(layer_attributes),
                 )
             else:
                 input_shape = target_model_graph.get_input_shape_for_insertion_point(insertion_point)
@@ -1181,7 +1183,7 @@ def is_weights(ip: PTTargetPoint) -> bool:
                 )
                 module_node = target_model_graph.get_node_by_name(primary_ip.target_node_name)
                 layer_attributes = module_node.layer_attributes
-                input_shape = layer_attributes.get_weight_shape()
+                input_shape = get_weight_shape_legacy(layer_attributes)
                 self._quantizers_input_shapes[primary_qid] = tuple(input_shape)
             else:
                 primary_qid = NonWeightQuantizerId(primary_ip.target_node_name, primary_ip.input_port_id)
```

nncf/torch/quantization/init_range.py

+4 -2

```diff
@@ -18,6 +18,8 @@

 import nncf
 from nncf.common.graph.layer_attributes import WeightedLayerAttributes
+from nncf.common.graph.utils import get_target_dim_for_compression_legacy
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.common.quantization.initialization.range import RangeInitCollectorParams
 from nncf.common.quantization.initialization.range import RangeInitConfig
 from nncf.common.quantization.initialization.range import RangeInitParams
@@ -226,8 +228,8 @@ def get_all_scale_shapes_with_params(
             module_node = target_nncf_graph.get_node_by_name(qp.insertion_point.target_node_name)
             layer_attributes = module_node.layer_attributes
             assert isinstance(layer_attributes, WeightedLayerAttributes)
-            input_shape = layer_attributes.get_weight_shape()
-            channel_idx = layer_attributes.get_target_dim_for_compression()
+            input_shape = get_weight_shape_legacy(layer_attributes)
+            channel_idx = get_target_dim_for_compression_legacy(layer_attributes)
         else:
             input_shape = target_nncf_graph.get_input_shape_for_insertion_point(qp.insertion_point)
             channel_idx = 1  # channel dim for activations
```

nncf/torch/sparsity/const/algo.py

+2 -1

```diff
@@ -11,6 +11,7 @@
 from typing import Tuple

 from nncf.common.graph import NNCFNode
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.common.sparsity.statistics import ConstSparsityStatistics
 from nncf.common.statistics import NNCFStatistics
 from nncf.common.utils.api_marker import api
@@ -26,7 +27,7 @@
 @PT_COMPRESSION_ALGORITHMS.register("const_sparsity")
 class ConstSparsityBuilder(BaseSparsityAlgoBuilder):
     def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
-        return BinaryMask(target_module_node.layer_attributes.get_weight_shape())
+        return BinaryMask(get_weight_shape_legacy(target_module_node.layer_attributes))

     def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
         return ConstSparsityController(model, self._sparsified_module_info)
```

nncf/torch/sparsity/magnitude/algo.py

+2 -1

```diff
@@ -17,6 +17,7 @@
 from nncf.api.compression import CompressionStage
 from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
 from nncf.common.graph import NNCFNode
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.common.initialization.batchnorm_adaptation import BatchnormAdaptationAlgorithm
 from nncf.common.schedulers import StubCompressionScheduler
 from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
@@ -44,7 +45,7 @@
 @PT_COMPRESSION_ALGORITHMS.register("magnitude_sparsity")
 class MagnitudeSparsityBuilder(BaseSparsityAlgoBuilder):
     def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
-        return BinaryMask(target_module_node.layer_attributes.get_weight_shape())
+        return BinaryMask(get_weight_shape_legacy(target_module_node.layer_attributes))

     def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
         return MagnitudeSparsityController(model, self._sparsified_module_info, self.config)
```

nncf/torch/sparsity/rb/algo.py

+2 -1

```diff
@@ -18,6 +18,7 @@
 from nncf.api.compression import CompressionStage
 from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
 from nncf.common.graph import NNCFNode
+from nncf.common.graph.utils import get_weight_shape_legacy
 from nncf.common.schedulers import StubCompressionScheduler
 from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
 from nncf.common.sparsity.statistics import RBSparsityStatistics
@@ -44,7 +45,7 @@
 class RBSparsityBuilder(BaseSparsityAlgoBuilder):
     def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
         return RBSparsifyingWeight(
-            target_module_node.layer_attributes.get_weight_shape(),
+            get_weight_shape_legacy(target_module_node.layer_attributes),
             frozen=False,
             compression_lr_multiplier=compression_lr_multiplier,
         )
```
