Skip to content

Commit 1ec5605

Browse files
committedMar 20, 2025
Extended test for strip
1 parent 76f3c8e commit 1ec5605

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed
 

‎tests/torch/quantization/test_strip.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Tuple
12+
from typing import Any, Tuple
1313

1414
import numpy as np
1515
import pytest
1616
import torch
17+
from torch import nn
1718
from torch.quantization.fake_quantize import FakeQuantize
1819

1920
import nncf
@@ -26,6 +27,8 @@
2627
from nncf.parameters import StripFormat
2728
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
2829
from nncf.torch.quantization.layers import AsymmetricQuantizer
30+
from nncf.torch.quantization.layers import BaseQuantizer
31+
from nncf.torch.quantization.layers import BaseWeightsDecompressor
2932
from nncf.torch.quantization.layers import PTQuantizerSpec
3033
from nncf.torch.quantization.layers import SymmetricQuantizer
3134
from nncf.torch.quantization.strip import convert_to_torch_fakequantizer
@@ -330,6 +333,30 @@ def test_nncf_strip_api(strip_type, do_copy):
330333
assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize)
331334

332335

336+
def check_compression_modules(
337+
model_: nn.Module,
338+
expected_module_type: ExtraCompressionModuleType,
339+
not_expected_module_type: ExtraCompressionModuleType,
340+
expected_class: Any,
341+
) -> None:
342+
"""
343+
Checks if the given model has the expected compression module registered and not the unexpected one.
344+
Also verifies that the compression module is of the expected class type.
345+
346+
:param model_: The model to be checked, which should have an 'nncf' attribute with compression module methods.
347+
:param expected_module_type: The type of the compression module that is expected to be registered.
348+
:param not_expected_module_type: The type of the compression module that is not expected to be registered.
349+
:param expected_class: The class type that the expected compression module should be an instance of.
350+
"""
351+
print(model_)
352+
assert model_.nncf.is_compression_module_registered(expected_module_type)
353+
assert not model_.nncf.is_compression_module_registered(not_expected_module_type)
354+
compression_modules_dict = model_.nncf.get_compression_modules_by_type(expected_module_type)
355+
assert len(compression_modules_dict) == 1
356+
compression_module = next(iter(compression_modules_dict.values()))
357+
assert isinstance(compression_module, expected_class)
358+
359+
333360
@pytest.mark.parametrize(
334361
("mode", "torch_dtype", "atol"),
335362
(
@@ -358,11 +385,23 @@ def test_nncf_strip_lora_model(mode, torch_dtype, atol):
358385
if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
359386
compression_kwargs.update(dict(ratio=1, group_size=4, all_layers=True))
360387
compressed_model = nncf.compress_weights(model, **compression_kwargs)
388+
check_compression_modules(
389+
compressed_model,
390+
expected_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
391+
not_expected_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
392+
expected_class=BaseQuantizer,
393+
)
361394

362395
with torch.no_grad():
363396
compressed_output = compressed_model(example)
364397

365398
strip_compressed_model = nncf.strip(compressed_model, do_copy=True, strip_format=StripFormat.DQ)
399+
check_compression_modules(
400+
strip_compressed_model,
401+
expected_module_type=ExtraCompressionModuleType.EXTERNAL_OP,
402+
not_expected_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
403+
expected_class=BaseWeightsDecompressor,
404+
)
366405
stripped_output = strip_compressed_model(example)
367406

368407
assert torch.allclose(compressed_output, stripped_output, atol=atol)

0 commit comments

Comments
 (0)