14
14
15
15
from nncf .common .graph import NNCFGraph
16
16
from nncf .common .graph import NNCFNode
17
- from nncf .common .graph .layer_attributes import WeightedLayerAttributes
18
- from nncf .common .graph .layer_attributes import GenericWeightedLayerAttributes
19
- from nncf .common .graph .layer_attributes import LinearLayerAttributes
20
17
from nncf .common .graph .layer_attributes import ConvolutionLayerAttributes
18
+ from nncf .common .graph .layer_attributes import GenericWeightedLayerAttributes
21
19
from nncf .common .graph .layer_attributes import GroupNormLayerAttributes
20
+ from nncf .common .graph .layer_attributes import LinearLayerAttributes
21
+ from nncf .common .graph .layer_attributes import WeightedLayerAttributes
22
22
from nncf .common .graph .operator_metatypes import OperatorMetatype
23
23
from nncf .common .logging import nncf_logger
24
24
from nncf .common .pruning .utils import traverse_function
@@ -139,8 +139,7 @@ def get_reduction_axes(
139
139
return tuple (reduction_axes )
140
140
141
141
142
- def get_weight_shape_legacy (
143
- layer_attributes : WeightedLayerAttributes ) -> List [int ]:
142
+ def get_weight_shape_legacy (layer_attributes : WeightedLayerAttributes ) -> List [int ]:
144
143
"""
145
144
Returns hard-coded weights shape layout only for Torch and Tensorflow models.
146
145
@@ -151,31 +150,31 @@ def get_weight_shape_legacy(
151
150
return layer_attributes .weight_shape
152
151
153
152
if isinstance (layer_attributes , LinearLayerAttributes ):
154
- return [layer_attributes .out_features , layer_attributes .in_features ]
153
+ return [layer_attributes .out_features , layer_attributes .in_features ]
155
154
156
155
if isinstance (layer_attributes , ConvolutionLayerAttributes ):
157
156
if not layer_attributes .transpose :
158
- return [layer_attributes .out_channels ,
159
- layer_attributes .in_channels // layer_attributes .groups ,
160
- * layer_attributes .kernel_size ]
161
- return [layer_attributes .in_channels ,
162
- layer_attributes .out_channels // layer_attributes .groups ,
157
+ return [
158
+ layer_attributes .out_channels ,
159
+ layer_attributes .in_channels // layer_attributes .groups ,
163
160
* layer_attributes .kernel_size ]
161
+ return [
162
+ layer_attributes .in_channels ,
163
+ layer_attributes .out_channels // layer_attributes .groups ,
164
+ * layer_attributes .kernel_size ]
164
165
165
166
if isinstance (layer_attributes , GroupNormLayerAttributes ):
166
167
return [layer_attributes .num_channels ]
167
168
168
169
169
- def get_target_dim_for_compression_legacy (
170
- layer_attributes : WeightedLayerAttributes ) -> int :
170
+ def get_target_dim_for_compression_legacy (layer_attributes : WeightedLayerAttributes ) -> int :
171
171
"""
172
172
Returns hard-coded target dim for compression only for Torch and Tensorflow models.
173
173
174
174
:param layer_attributes: layer attributes of NNCFNode.
175
175
:return: target dim for compression.
176
176
"""
177
- if isinstance (layer_attributes , (GenericWeightedLayerAttributes ,
178
- LinearLayerAttributes , GroupNormLayerAttributes )):
177
+ if isinstance (layer_attributes , (GenericWeightedLayerAttributes , LinearLayerAttributes , GroupNormLayerAttributes )):
179
178
return 0
180
179
181
180
if isinstance (layer_attributes , ConvolutionLayerAttributes ):
@@ -185,8 +184,7 @@ def get_target_dim_for_compression_legacy(
185
184
return 0
186
185
187
186
188
- def get_bias_shape_legacy (
189
- layer_attributes : WeightedLayerAttributes ) -> int :
187
+ def get_bias_shape_legacy (layer_attributes : WeightedLayerAttributes ) -> int :
190
188
"""
191
189
Returns hard-coded bias shape only for Torch and Tensorflow models.
192
190
@@ -195,3 +193,15 @@ def get_bias_shape_legacy(
195
193
"""
196
194
if isinstance (layer_attributes , LinearLayerAttributes ):
197
195
return layer_attributes .out_features if layer_attributes .with_bias is True else 0
196
+
197
+
198
+ def get_num_filters_legacy (layer_attributes : WeightedLayerAttributes ) -> int :
199
+ """
200
+ Returns hard-coded number of filters only for Torch and Tensorflow models.
201
+
202
+ :param layer_attributes: layer attributes of NNCFNode.
203
+ :return: number of filters.
204
+ """
205
+ weight_shape = layer_attributes .get_weight_shape_legacy ()
206
+ return weight_shape [layer_attributes .get_target_dim_for_compression_legacy ()]
207
+
0 commit comments