Skip to content

Commit 4aec978

Browse files
committed
[NNCF] Add get_weight_shape_legacy function (openvinotoolkit#3249)
Changes: Moved and renamed backend-specific class method from common layer attributes in nncf/common/graph/layer_attributes.py to self-contained function in nncf/common/graph/utils.py - layer_attributes.get_weight_shape -> get_weight_shape_legacy(layer_attributes: WeightedLayerAttributes) Reason for changes: (openvinotoolkit#3249) Torch and Tensorflow backend-specific methods need to be removed from common layer attributes and all related calls need to be replaced by their corresponding legacy function calls
1 parent b37408f commit 4aec978

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

nncf/common/graph/utils.py

+27
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,30 @@ def get_reduction_axes(
132132
for channel_axis in sorted(channel_axes, reverse=True):
133133
del reduction_axes[channel_axis]
134134
return tuple(reduction_axes)
135+
136+
137+
def get_weight_shape_legacy(
138+
layer_attributes: WeightedLayerAttributes) -> List[int]:
139+
"""
140+
Returns hard-coded weights shape layout only for Torch and Tensorflow models.
141+
142+
:param layer_attributes: layer attributes of NNCFNode.
143+
:return: weights shape layout.
144+
"""
145+
if isinstance(layer_attributes, GenericWeightedLayerAttributes):
146+
return layer_attributes.weight_shape
147+
148+
if isinstance(layer_attributes, LinearLayerAttributes):
149+
return[layer_attributes.out_features, layer_attributes.in_features]
150+
151+
if isinstance(layer_attributes, ConvolutionLayerAttributes):
152+
if not layer_attributes.transpose:
153+
return [layer_attributes.out_channels,
154+
layer_attributes.in_channels // layer_attributes.groups,
155+
*layer_attributes.kernel_size]
156+
return [layer_attributes.in_channels,
157+
layer_attributes.out_channels // layer_attributes.groups,
158+
*layer_attributes.kernel_size]
159+
160+
if isinstance(layer_attributes, GroupNormLayerAttributes):
161+
return [layer_attributes.num_channels]

0 commit comments

Comments
 (0)