@@ -9,11 +9,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Tuple
 
 import numpy as np
 import pytest
 import torch
+from torch import nn
 from torch.quantization.fake_quantize import FakeQuantize
 
 import nncf
@@ -26,6 +27,8 @@
 from nncf.parameters import StripFormat
 from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.quantization.layers import AsymmetricQuantizer
+from nncf.torch.quantization.layers import BaseQuantizer
+from nncf.torch.quantization.layers import BaseWeightsDecompressor
 from nncf.torch.quantization.layers import PTQuantizerSpec
 from nncf.torch.quantization.layers import SymmetricQuantizer
 from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
@@ -330,6 +333,29 @@ def test_nncf_strip_api(strip_type, do_copy):
     assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize)
 
 
+def check_compression_modules(
+    model_: nn.Module,
+    expected_module_type: ExtraCompressionModuleType,
+    not_expected_module_type: ExtraCompressionModuleType,
+    expected_class: Any,
+) -> None:
+    """
+    Checks that the given model has the expected compression module type registered and not the unexpected one.
+    Also verifies that the single registered compression module is an instance of the expected class.
+
+    :param model_: The model to be checked; it must expose an 'nncf' attribute with compression module methods.
+    :param expected_module_type: The compression module type that is expected to be registered.
+    :param not_expected_module_type: The compression module type that must not be registered.
+    :param expected_class: The class that the registered compression module should be an instance of.
+    """
+    assert model_.nncf.is_compression_module_registered(expected_module_type)
+    assert not model_.nncf.is_compression_module_registered(not_expected_module_type)
+    compression_modules_dict = model_.nncf.get_compression_modules_by_type(expected_module_type)
+    assert len(compression_modules_dict) == 1
+    compression_module = next(iter(compression_modules_dict.values()))
+    assert isinstance(compression_module, expected_class)
+
+
 @pytest.mark.parametrize(
     ("mode", "torch_dtype", "atol"),
     (
@@ -358,11 +384,23 @@ def test_nncf_strip_lora_model(mode, torch_dtype, atol):
     if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
         compression_kwargs.update(dict(ratio=1, group_size=4, all_layers=True))
     compressed_model = nncf.compress_weights(model, **compression_kwargs)
+    check_compression_modules(
+        compressed_model,
+        expected_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
+        not_expected_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
+        expected_class=BaseQuantizer,
+    )
 
     with torch.no_grad():
         compressed_output = compressed_model(example)
 
     strip_compressed_model = nncf.strip(compressed_model, do_copy=True, strip_format=StripFormat.DQ)
+    check_compression_modules(
+        strip_compressed_model,
+        expected_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
+        not_expected_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
+        expected_class=BaseWeightsDecompressor,
+    )
    stripped_output = strip_compressed_model(example)
 
     assert torch.allclose(compressed_output, stripped_output, atol=atol)
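
The behavior this test pins down: compression registers `EXTERNAL_QUANTIZER` modules (FakeQuantize-style quantize-dequantize applied to float weights), while `nncf.strip(..., strip_format=StripFormat.DQ)` swaps them for `EXTERNAL_OP` weight decompressors that dequantize stored integer weights, and the final `allclose` asserts the two paths produce matching outputs. The sketch below illustrates that equivalence in plain PyTorch; it is a conceptual illustration, not NNCF's actual quantizer or decompressor code, and the helper names, scale/zero-point choices, and tensor shapes are invented for the example.

```python
# Conceptual sketch only -- plain PyTorch, not NNCF internals.
# Shows why a fake-quantized float weight and a stored-uint8 weight plus a
# "decompressor" (dequantize on the fly) give matching layer outputs.
import torch


def fake_quantize_asym(w: torch.Tensor, scale: torch.Tensor, zp: torch.Tensor) -> torch.Tensor:
    # Quantize-dequantize in float, as a FakeQuantize-style module would.
    q = torch.clamp(torch.round(w / scale) + zp, 0, 255)
    return scale * (q - zp)


def decompress_asym(q: torch.Tensor, scale: torch.Tensor, zp: torch.Tensor) -> torch.Tensor:
    # DQ strip format idea: the weight is stored as uint8 and only the
    # dequantization w = scale * (q - zp) runs at inference time.
    return scale * (q.to(torch.float32) - zp)


w = torch.randn(4, 8)
scale = w.abs().amax(dim=1, keepdim=True) / 127.5  # per-output-channel scale
zp = torch.full((4, 1), 128.0)                     # asymmetric zero point

# Path 1: fake-quantized float weight (what the compressed model computes).
w_fq = fake_quantize_asym(w, scale, zp)

# Path 2: stored integer weight + decompressor (what the stripped model computes).
q = torch.clamp(torch.round(w / scale) + zp, 0, 255).to(torch.uint8)
w_dq = decompress_asym(q, scale, zp)

x = torch.ones(1, 8)
assert torch.allclose(x @ w_fq.T, x @ w_dq.T, atol=1e-6)
```

The uint8 round trip is lossless for values already clamped to [0, 255], so the two weight tensors match exactly; the test's mode-dependent `atol` accounts for the lower-precision dtypes (for example float16) exercised by the parametrization.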