|
17 | 17 | from nncf import SensitivityMetric
|
18 | 18 | from nncf.quantization import compress_weights
|
19 | 19 | from nncf.torch import wrap_model
|
| 20 | +from nncf.torch.quantization.layers import WeightsDecompressor |
20 | 21 |
|
21 | 22 | DATA_BASED_SENSITIVITY_METRICS = (
|
22 | 23 | SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
|
@@ -52,6 +53,32 @@ def forward(self, input_ids):
|
52 | 53 | return res
|
53 | 54 |
|
54 | 55 |
|
class NestedMatMul(torch.nn.Module):
    """Submodule that right-multiplies its input by a learnable matrix.

    The weight ``w`` is a 300x300 float32 parameter initialised to ones and is
    applied with the ``@`` operator rather than through an ``nn.Linear`` layer.
    """

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(size=(300, 300), dtype=torch.float32))

    def forward(self, input):
        """Return ``input @ self.w``.

        :param input: Tensor whose last dimension is 300.
        :return: The matrix product of ``input`` and ``self.w``.
        """
        # The parameter name shadows the builtin, but it is kept so that
        # keyword callers (``module(input=...)``) keep working.
        return input @ self.w
| 63 | + |
| 64 | + |
class FunctionalModel(torch.nn.Module):
    """Model whose weights are consumed by functional ops instead of ``nn`` layers.

    It owns three raw parameters (``conv_w``, ``matmul_w``, ``conv_tr_w``) and one
    nested submodule (``nested_matmul``) that holds a fourth weight.
    """

    def __init__(self):
        super().__init__()
        # Weight for F.conv2d: 5 output channels, 3 input channels, 3x3 kernel.
        self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
        # Right-hand operand of a batched matmul over (1, 3, 300, 300) tensors.
        self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32))
        # Weight for F.conv_transpose2d: 5 input channels, 4 output channels, 3x3 kernel.
        self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
        self.nested_matmul = NestedMatMul()

    def forward(self, input_):
        """Run matmul -> nested matmul -> conv2d -> conv_transpose2d.

        :param input_: Tensor of any dtype; it is cast to float32 first. The shape
            comments below assume a (1, 3, 300, 300) input.
        :return: Output of the final transposed convolution.
        """
        x = input_.to(torch.float32)
        x = x @ self.matmul_w  # (1, 3, 300, 300)
        x = self.nested_matmul(x)  # (1, 3, 300, 300)
        x = F.conv2d(x, self.conv_w)  # (1, 5, 298, 298)
        x = F.conv_transpose2d(x, self.conv_tr_w)  # (1, 4, 300, 300)
        return x
| 80 | + |
| 81 | + |
55 | 82 | class ConvolutionModel(torch.nn.Module):
|
56 | 83 | def __init__(self):
|
57 | 84 | super().__init__()
|
@@ -98,6 +125,20 @@ def test_compress_weights():
|
98 | 125 | assert n_compressed_weights == n_target_modules
|
99 | 126 |
|
100 | 127 |
|
def test_compress_weights_functional_model():
    """Check that weights used through functional ops get compressed.

    ``FunctionalModel`` is wrapped with parameter tracing enabled, passed to
    ``compress_weights``, and the ``WeightsDecompressor`` instances registered
    among the model's external ops are counted.
    """
    model = FunctionalModel()

    input_ids = torch.randint(0, 10, [1, 3, 300, 300])
    wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True)
    compressed_model = compress_weights(wrapped_model)

    n_compressed_weights = sum(
        isinstance(layer, WeightsDecompressor) for layer in compressed_model.nncf.external_op.values()
    )
    # FunctionalModel holds four weights (matmul_w, nested_matmul.w, conv_w,
    # conv_tr_w); presumably one decompressor is expected per weight.
    assert n_compressed_weights == 4
| 140 | + |
| 141 | + |
101 | 142 | def test_compress_weights_conv():
|
102 | 143 | model = ConvolutionModel()
|
103 | 144 |
|
|
0 commit comments