@@ -146,13 +146,10 @@ def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[i
146
146
:param layer_attributes: layer attributes of NNCFNode.
147
147
:return: weights shape layout.
148
148
"""
149
- if isinstance (layer_attributes , GenericWeightedLayerAttributes ):
150
- return layer_attributes .weight_shape
151
-
152
149
if isinstance (layer_attributes , LinearLayerAttributes ):
153
150
return [layer_attributes .out_features , layer_attributes .in_features ]
154
151
155
- if isinstance (layer_attributes , ConvolutionLayerAttributes ):
152
+ elif isinstance (layer_attributes , ConvolutionLayerAttributes ):
156
153
if not layer_attributes .transpose :
157
154
return [
158
155
layer_attributes .out_channels ,
@@ -165,9 +162,13 @@ def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[i
165
162
* layer_attributes .kernel_size ,
166
163
]
167
164
168
- if isinstance (layer_attributes , GroupNormLayerAttributes ):
165
+ elif isinstance (layer_attributes , GroupNormLayerAttributes ):
169
166
return [layer_attributes .num_channels ]
170
167
168
+ else :
169
+ assert isinstance (layer_attributes , GenericWeightedLayerAttributes ):
170
+ return layer_attributes .weight_shape
171
+
171
172
172
173
def get_target_dim_for_compression_legacy (layer_attributes : WeightedLayerAttributes ) -> int :
173
174
"""
@@ -176,15 +177,17 @@ def get_target_dim_for_compression_legacy(layer_attributes: WeightedLayerAttribu
176
177
:param layer_attributes: layer attributes of NNCFNode.
177
178
:return: target dim for compression.
178
179
"""
179
- if isinstance (layer_attributes , (GenericWeightedLayerAttributes , LinearLayerAttributes , GroupNormLayerAttributes )):
180
- return 0
181
-
182
180
if isinstance (layer_attributes , ConvolutionLayerAttributes ):
183
181
# Always quantize per each "out" channel
184
182
if layer_attributes .transpose :
185
183
return 1
186
184
return 0
187
185
186
+ else :
187
+ assert isinstance (layer_attributes , (GenericWeightedLayerAttributes , LinearLayerAttributes , GroupNormLayerAttributes )):
188
+ return 0
189
+
190
+
188
191
189
192
def get_bias_shape_legacy (layer_attributes : WeightedLayerAttributes ) -> int :
190
193
"""
@@ -193,8 +196,11 @@ def get_bias_shape_legacy(layer_attributes: WeightedLayerAttributes) -> int:
193
196
:param layer_attributes: layer attributes of NNCFNode.
194
197
:return: bias shape.
195
198
"""
196
- if isinstance (layer_attributes , LinearLayerAttributes ):
197
- return layer_attributes .out_features if layer_attributes .with_bias is True else 0
199
+ if isinstance (layer_attributes , LinearLayerAttributes ) and layer_attributes .with_bias :
200
+ return layer_attributes .out_features
201
+
202
+ else :
203
+ return 0
198
204
199
205
200
206
def get_num_filters_legacy (layer_attributes : WeightedLayerAttributes ) -> int :
0 commit comments