Skip to content

Commit 2c888b7

Browse files
committed
mypy
1 parent 7009c42 commit 2c888b7

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

nncf/common/graph/utils.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,10 @@ def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[i
146146
:param layer_attributes: layer attributes of NNCFNode.
147147
:return: weights shape layout.
148148
"""
149-
if isinstance(layer_attributes, GenericWeightedLayerAttributes):
150-
return layer_attributes.weight_shape
151-
152149
if isinstance(layer_attributes, LinearLayerAttributes):
153150
return [layer_attributes.out_features, layer_attributes.in_features]
154151

155-
if isinstance(layer_attributes, ConvolutionLayerAttributes):
152+
elif isinstance(layer_attributes, ConvolutionLayerAttributes):
156153
if not layer_attributes.transpose:
157154
return [
158155
layer_attributes.out_channels,
@@ -165,9 +162,13 @@ def get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) -> List[i
165162
*layer_attributes.kernel_size,
166163
]
167164

168-
if isinstance(layer_attributes, GroupNormLayerAttributes):
165+
elif isinstance(layer_attributes, GroupNormLayerAttributes):
169166
return [layer_attributes.num_channels]
170167

168+
else:
169+
assert isinstance(layer_attributes, GenericWeightedLayerAttributes):
170+
return layer_attributes.weight_shape
171+
171172

172173
def get_target_dim_for_compression_legacy(layer_attributes: WeightedLayerAttributes) -> int:
173174
"""
@@ -176,15 +177,17 @@ def get_target_dim_for_compression_legacy(layer_attributes: WeightedLayerAttribu
176177
:param layer_attributes: layer attributes of NNCFNode.
177178
:return: target dim for compression.
178179
"""
179-
if isinstance(layer_attributes, (GenericWeightedLayerAttributes, LinearLayerAttributes, GroupNormLayerAttributes)):
180-
return 0
181-
182180
if isinstance(layer_attributes, ConvolutionLayerAttributes):
183181
# Always quantize per each "out" channel
184182
if layer_attributes.transpose:
185183
return 1
186184
return 0
187185

186+
else:
187+
assert isinstance(layer_attributes, (GenericWeightedLayerAttributes, LinearLayerAttributes, GroupNormLayerAttributes)):
188+
return 0
189+
190+
188191

189192
def get_bias_shape_legacy(layer_attributes: WeightedLayerAttributes) -> int:
190193
"""
@@ -193,8 +196,11 @@ def get_bias_shape_legacy(layer_attributes: WeightedLayerAttributes) -> int:
193196
:param layer_attributes: layer attributes of NNCFNode.
194197
:return: bias shape.
195198
"""
196-
if isinstance(layer_attributes, LinearLayerAttributes):
197-
return layer_attributes.out_features if layer_attributes.with_bias is True else 0
199+
if isinstance(layer_attributes, LinearLayerAttributes) and layer_attributes.with_bias:
200+
return layer_attributes.out_features
201+
202+
else:
203+
return 0
198204

199205

200206
def get_num_filters_legacy(layer_attributes: WeightedLayerAttributes) -> int:

0 commit comments

Comments
 (0)