Skip to content

Commit dde704a

Browse files
authored
Update utils.py
1 parent 40a9561 commit dde704a

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

nncf/common/graph/utils.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414

1515
from nncf.common.graph import NNCFGraph
1616
from nncf.common.graph import NNCFNode
17-
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
18-
from nncf.common.graph.layer_attributes import GenericWeightedLayerAttributes
19-
from nncf.common.graph.layer_attributes import LinearLayerAttributes
2017
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
18+
from nncf.common.graph.layer_attributes import GenericWeightedLayerAttributes
2119
from nncf.common.graph.layer_attributes import GroupNormLayerAttributes
20+
from nncf.common.graph.layer_attributes import LinearLayerAttributes
21+
from nncf.common.graph.layer_attributes import WeightedLayerAttributes
2222
from nncf.common.graph.operator_metatypes import OperatorMetatype
2323
from nncf.common.logging import nncf_logger
2424
from nncf.common.pruning.utils import traverse_function
@@ -139,8 +139,7 @@ def get_reduction_axes(
139139
return tuple(reduction_axes)
140140

141141

142-
def get_weight_shape_legacy(
143-
layer_attributes: WeightedLayerAttributes) -> List[int]:
142+
def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[int]:
144143
"""
145144
Returns hard-coded weights shape layout only for Torch and Tensorflow models.
146145
@@ -151,31 +150,31 @@ def get_weight_shape_legacy(
151150
return layer_attributes.weight_shape
152151

153152
if isinstance(layer_attributes, LinearLayerAttributes):
154-
return[layer_attributes.out_features, layer_attributes.in_features]
153+
return [layer_attributes.out_features, layer_attributes.in_features]
155154

156155
if isinstance(layer_attributes, ConvolutionLayerAttributes):
157156
if not layer_attributes.transpose:
158-
return [layer_attributes.out_channels,
159-
layer_attributes.in_channels // layer_attributes.groups,
160-
*layer_attributes.kernel_size]
161-
return [layer_attributes.in_channels,
162-
layer_attributes.out_channels // layer_attributes.groups,
157+
return [
158+
layer_attributes.out_channels,
159+
layer_attributes.in_channels // layer_attributes.groups,
163160
*layer_attributes.kernel_size]
161+
return [
162+
layer_attributes.in_channels,
163+
layer_attributes.out_channels // layer_attributes.groups,
164+
*layer_attributes.kernel_size]
164165

165166
if isinstance(layer_attributes, GroupNormLayerAttributes):
166167
return [layer_attributes.num_channels]
167168

168169

169-
def get_target_dim_for_compression_legacy(
170-
layer_attributes: WeightedLayerAttributes) -> int:
170+
def get_target_dim_for_compression_legacy(layer_attributes: WeightedLayerAttributes) -> int:
171171
"""
172172
Returns hard-coded target dim for compression only for Torch and Tensorflow models.
173173
174174
:param layer_attributes: layer attributes of NNCFNode.
175175
:return: target dim for compression.
176176
"""
177-
if isinstance(layer_attributes, (GenericWeightedLayerAttributes,
178-
LinearLayerAttributes, GroupNormLayerAttributes)):
177+
if isinstance(layer_attributes, (GenericWeightedLayerAttributes, LinearLayerAttributes, GroupNormLayerAttributes)):
179178
return 0
180179

181180
if isinstance(layer_attributes, ConvolutionLayerAttributes):
@@ -185,8 +184,7 @@ def get_target_dim_for_compression_legacy(
185184
return 0
186185

187186

188-
def get_bias_shape_legacy(
189-
layer_attributes: WeightedLayerAttributes) -> int:
187+
def get_bias_shape_legacy(layer_attributes: WeightedLayerAttributes) -> int:
190188
"""
191189
Returns hard-coded bias shape only for Torch and Tensorflow models.
192190
@@ -195,3 +193,15 @@ def get_bias_shape_legacy(
195193
"""
196194
if isinstance(layer_attributes, LinearLayerAttributes):
197195
return layer_attributes.out_features if layer_attributes.with_bias is True else 0
196+
197+
198+
def get_num_filters_legacy(layer_attributes: WeightedLayerAttributes) -> int:
199+
"""
200+
Returns hard-coded number of filters only for Torch and Tensorflow models.
201+
202+
:param layer_attributes: layer attributes of NNCFNode.
203+
:return: number of filters.
204+
"""
205+
weight_shape = layer_attributes.get_weight_shape_legacy()
206+
return weight_shape[layer_attributes.get_target_dim_for_compression_legacy()]
207+

0 commit comments

Comments
 (0)