Skip to content

Commit 2872f27

Browse files
authored
[PT][WC] Add WC tests for functional models (#2446)
### Changes Fixed bug with `channel_idx` for Torch Conv metatypes Fixed bug in `get_module_by_name()` if base model has weights as a `torch.nn.Parameter` ### Reason for changes <!--- Why should the change be applied --> ### Related tickets 124822 ### Tests Added `test_compress_weights_functional_model`
1 parent 1391667 commit 2872f27

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

nncf/quantization/algorithms/weight_compression/torch_backend.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,16 @@
4040

4141
def split_weight_name(weight_name: str) -> Tuple[str, str]:
4242
index = weight_name.rfind(".")
43+
if index == -1:
44+
return str(), weight_name
4345
module_name = weight_name[:index]
4446
weight_attr_name = weight_name[index + 1 :]
4547
return module_name, weight_attr_name
4648

4749

4850
def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Module:
51+
if not module_name:
52+
return model
4953
curr_module = model
5054
for name in module_name.split("."):
5155
for child_name, child_module in curr_module.named_children():
@@ -161,8 +165,12 @@ def get_channel_agnostic_reduction_axes(
161165
elif weight_port_id == 2:
162166
reduction_axes = [max(0, ndims - 2)]
163167
elif node_with_weight.metatype in PTWeightCompressionAlgoBackend.CONVOLUTION_METATYPES:
164-
layer_attributes = node_with_weight.layer_attributes
165-
channel_idx = layer_attributes.get_target_dim_for_compression()
168+
channel_idx = (
169+
1
170+
if node_with_weight.metatype
171+
in [om.PTConvTranspose1dMetatype, om.PTConvTranspose2dMetatype, om.PTConvTranspose3dMetatype]
172+
else 0
173+
)
166174
reduction_axes = [i for i in range(ndims) if i != channel_idx]
167175
return tuple(reduction_axes)
168176

tests/torch/ptq/test_weights_compression.py

+41
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from nncf import SensitivityMetric
1818
from nncf.quantization import compress_weights
1919
from nncf.torch import wrap_model
20+
from nncf.torch.quantization.layers import WeightsDecompressor
2021

2122
DATA_BASED_SENSITIVITY_METRICS = (
2223
SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
@@ -52,6 +53,32 @@ def forward(self, input_ids):
5253
return res
5354

5455

56+
class NestedMatMul(torch.nn.Module):
57+
def __init__(self):
58+
super().__init__()
59+
self.w = torch.nn.Parameter(torch.ones(size=(300, 300), dtype=torch.float32))
60+
61+
def forward(self, input):
62+
return input @ self.w
63+
64+
65+
class FunctionalModel(torch.nn.Module):
66+
def __init__(self):
67+
super().__init__()
68+
self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
69+
self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32))
70+
self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
71+
self.nested_matmul = NestedMatMul()
72+
73+
def forward(self, input_):
74+
x = input_.to(torch.float32)
75+
x = x @ self.matmul_w
76+
x = self.nested_matmul(x)
77+
x = F.conv2d(x, self.conv_w)
78+
x = F.conv_transpose2d(x, self.conv_tr_w)
79+
return x
80+
81+
5582
class ConvolutionModel(torch.nn.Module):
5683
def __init__(self):
5784
super().__init__()
@@ -98,6 +125,20 @@ def test_compress_weights():
98125
assert n_compressed_weights == n_target_modules
99126

100127

128+
def test_compress_weights_functional_model():
129+
model = FunctionalModel()
130+
131+
input_ids = torch.randint(0, 10, [1, 3, 300, 300])
132+
wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True)
133+
compressed_model = compress_weights(wrapped_model)
134+
135+
n_compressed_weights = 0
136+
for layer in compressed_model.nncf.external_op.values():
137+
if isinstance(layer, WeightsDecompressor):
138+
n_compressed_weights += 1
139+
assert n_compressed_weights == 4
140+
141+
101142
def test_compress_weights_conv():
102143
model = ConvolutionModel()
103144

0 commit comments

Comments
 (0)