
Commit 043b549

Merge pull request #1 from raghavx1/patch-1
Update template_test_weights_compression.py
2 parents 72936ab + 06d3707

1 file changed: +150 -1 lines changed

tests/cross_fw/test_templates/template_test_weights_compression.py

@@ -43,7 +43,156 @@
 
 INT4_MODES = (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)
 
-
+class TemplateTestFBCAlgorithm:
+    @pytest.mark.parametrize("mode", SUPPORTED_MODES)
+    def test_compress_weights(self, mode):
+        model = ShortTransformer(8, 16)
+        dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
+
+        input_ids = torch.randint(0, 10, (8,))
+        wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True)
+
+        kwargs = {}
+        if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+            kwargs["group_size"] = 4
+        compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs)
+
+        n_compressed_weights = 0
+        n_target_modules = 0
+
+        for _, module in compressed_model.named_children():
+            if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+                n_target_modules += 1
+                if module.weight.dtype == dtype:
+                    n_compressed_weights += 1
+
+        assert n_compressed_weights == n_target_modules
+
+    @pytest.mark.parametrize("mode", SUPPORTED_MODES)
+    def test_compress_shared_weights(self, mocker, mode):
+        model = ShortTransformer(8, 16, share_weights=True)
+        dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
+
+        input_ids = torch.randint(0, 10, (8,))
+        wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True)
+
+        kwargs = {}
+        if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+            kwargs["group_size"] = 4
+        compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs)
+
+        n_compressed_weights = 0
+        n_target_modules = 0
+
+        for _, module in compressed_model.named_children():
+            if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+                n_target_modules += 1
+                if module.weight.dtype == dtype:
+                    n_compressed_weights += 1
+
+        assert n_compressed_weights == n_target_modules
+        assert len(compressed_model.nncf.external_op) == 2
+
+        # check that the weight decompressors are called only once
+        for val in compressed_model.nncf.external_op.values():
+            mocker.spy(val, "forward")
+
+        compressed_model(input_ids)
+
+        for val in compressed_model.nncf.external_op.values():
+            assert val.forward.call_count == 1
+
+    @pytest.mark.parametrize("mode", INT8_MODES)
+    @pytest.mark.parametrize(
+        "params",
+        (
+            {"ratio": 0.5},
+            {"group_size": 64},
+            {"all_layers": True},
+            {"all_layers": False},
+            *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS),
+            {"gptq": True},
+            {"scale_estimation": True},
+            {"lora_correction": True},
+            {"backup_mode": BackupMode.NONE},
+            {"backup_mode": BackupMode.INT8_ASYM},
+            {"backup_mode": BackupMode.INT8_SYM},
+            {"compression_format": CompressionFormat.FQ, "group_size": 64},
+            {"advanced_parameters": AdvancedCompressionParameters(statistics_path="anything")},
+        ),
+    )
+    def test_raise_error_with_unsupported_params_for_int8(self, mode, params):
+        dummy_torch_model = EmptyModel()
+        dummy_input = torch.Tensor()
+        wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True)
+        with pytest.raises(nncf.ParameterNotSupportedError):
+            compress_weights(wrapped_model, mode=mode, **params)
+
+
+    @pytest.mark.parametrize("mode", INT4_MODES)
+    @pytest.mark.parametrize(
+        "params",
+        (
+            {"gptq": True},
+            {"lora_correction": True},
+            {"compression_format": CompressionFormat.FQ, "group_size": 64},
+        ),
+    )
+    def test_raise_error_with_unsupported_params_for_int4(self, mode, params):
+        dummy_torch_model = EmptyModel()
+        dummy_input = torch.Tensor()
+        wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True)
+        with pytest.raises(nncf.ParameterNotSupportedError):
+            compress_weights(wrapped_model, mode=mode, **params)
+
+    @pytest.mark.parametrize("mode", UNSUPPORTED_MODES)
+    def test_raise_error_with_not_int8(self, mode):
+        dummy_torch_model = EmptyModel()
+        dummy_input = torch.Tensor()
+        wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True)
+        with pytest.raises(nncf.ParameterNotSupportedError):
+            compress_weights(wrapped_model, mode=mode)
+
+    def test_raise_error_for_statistics_caching(self):
+        dummy_torch_model = EmptyModel()
+        dummy_input = torch.Tensor()
+        wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True)
+        with pytest.raises(nncf.ParameterNotSupportedError):
+            compress_weights(wrapped_model, advanced_parameters=AdvancedCompressionParameters(statistics_path="anything"))
+
+    def test_get_dtype_attribute_of_parameter(self):
+        model = DTypeModel()
+        dummy_input = torch.randint(0, 10, [3, 3])
+        wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True)
+        compressed_model = compress_weights(wrapped_model)
+        assert compressed_model.weight.dtype == torch.uint8
+        compressed_model(dummy_input)
+        assert compressed_model.weight.dtype == torch.uint8
+
+
+    @pytest.mark.parametrize("dtype", ("float16", "float32"))
+    def test_model_devices_and_precisions(self, use_cuda, dtype):
+        if use_cuda and not torch.cuda.is_available():
+            pytest.skip("Skipping for CPU-only setups")
+        device = torch.device("cuda" if use_cuda else "cpu")
+        dtype = torch.float16 if dtype == "float16" else torch.float32
+
+        model = MatMulModel().to(device)
+        if dtype == torch.float16:
+            model.half()
+
+        dummy_input = torch.rand((1, 256), dtype=dtype, device=device)
+        wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True)
+        compressed_model = compress_weights(wrapped_model)
+        result = compressed_model(dummy_input)
+
+        # Scale should always be in float16
+        assert compressed_model.state_dict()["_nncf.external_op.weights_decompressor_w._scale"].dtype == torch.float16
+        # Result should be in the precision of the model
+        assert result.dtype == dtype
+
+
+
 def get_relative_error(weight_1: Tensor, weight_2: Tensor, axis: int = 0) -> Tensor:
     diff = (weight_1 - weight_2) ** 2
     return fns.mean(diff, axis=axis) / fns.mean(weight_1**2, axis=axis)
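For context on the API these tests exercise: the common pattern is to wrap a torch module so NNCF can trace its parameters, then compress the weights. The sketch below is illustrative only; the toy Sequential model and the chosen group_size are stand-ins and not part of this commit, and it assumes nncf and torch are installed with wrap_model importable from nncf.torch.

import torch

from nncf import CompressWeightsMode, compress_weights
from nncf.torch import wrap_model

# Hypothetical toy model standing in for ShortTransformer from the tests:
# an embedding followed by a linear projection.
model = torch.nn.Sequential(torch.nn.Embedding(10, 8), torch.nn.Linear(8, 8))
example_input = torch.randint(0, 10, (8,))

# trace_parameters=True, as in the tests above, lets NNCF locate the weights.
wrapped = wrap_model(model, example_input=example_input, trace_parameters=True)

# INT4 modes take a group_size; the tests use 4, which divides the
# 8-element channel dimension of both weights here.
compressed = compress_weights(wrapped, mode=CompressWeightsMode.INT4_SYM, group_size=4)
print(compressed(example_input).shape)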
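The get_relative_error helper kept as context at the bottom of the hunk computes a per-channel relative squared error: mean((w1 - w2)^2, axis) / mean(w1^2, axis). A minimal usage sketch, assuming the helper is in scope and that NNCF's tensor wrapper is importable as from nncf.tensor import Tensor (the import path may vary between releases); the weights and the crude rounding are hypothetical stand-ins for a quantization round-trip.

import numpy as np

from nncf.tensor import Tensor

# Hypothetical original weights, shape (out_channels, in_channels).
weight = Tensor(np.random.rand(16, 8).astype(np.float32))
# Crudely rounded copy standing in for a compress/decompress round-trip.
rounded = Tensor(np.round(weight.data * 16) / 16)

# Reducing over axis=1 yields one relative-error value per output channel.
error = get_relative_error(weight, rounded, axis=1)
assert error.shape == (16,)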
