Commit efe85a2
Added tests for load_in_4bit
1 parent 1e87775 commit efe85a2

1 file changed (+48 -17): tests/openvino/test_quantization.py
@@ -155,7 +155,34 @@ class OVWeightCompressionTest(unittest.TestCase):
 
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 365),)
     SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = (
-        (OVModelForCausalLM, "opt125m", 64, 477),
+        (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 46),
+    )
+
+    LOAD_IN_4_BITS_SCOPE = (
+        (
+            OVModelForCausalLM,
+            "hf-internal-testing/tiny-random-gpt2",
+            dict(mode=nncf.CompressWeightsMode.INT4_ASYM, group_size=-1, ratio=0.8),
+            16,
+        ),
+        (
+            OVModelForCausalLM,
+            "hf-internal-testing/tiny-random-gpt2",
+            dict(
+                mode=nncf.CompressWeightsMode.INT4_ASYM,
+                group_size=-1,
+                ignored_scope=nncf.IgnoredScope(names=["__module.model.transformer.h.2.mlp.c_fc/aten::addmm/MatMul"]),
+            ),
+            6,
+        ),
+        (
+            OVModelForCausalLM,
+            "hf-internal-testing/tiny-random-gpt2",
+            dict(mode=nncf.CompressWeightsMode.INT4_ASYM, group_size=-1, ratio=0.8, all_layers=True),
+            22,
+        ),
+        # TODO: uncomment after fix
+        # (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", dict(mode=nncf.CompressWeightsMode.INT4_SYM, group_size=-1, ratio=0.8, sensitivity_metric=nncf.SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE, dataset="ptb"), 16),
     )
 
     SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = (
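For reference, each LOAD_IN_4_BITS_SCOPE entry pairs a model class and checkpoint with a weight-compression config and the number of INT4 MatMul nodes expected after compression. A minimal sketch of how these config dicts are built from NNCF types (illustrative only, not part of the test file):

# Illustrative sketch of the config dicts used in LOAD_IN_4_BITS_SCOPE above.
import nncf

# Asymmetric INT4; group_size=-1 keeps per-channel scales (no grouping);
# ratio=0.8 leaves roughly 20% of the weights at 8 bit.
basic_config = dict(mode=nncf.CompressWeightsMode.INT4_ASYM, group_size=-1, ratio=0.8)

# Exclude one specific MatMul from compression via an ignored scope.
ignored_config = dict(
    mode=nncf.CompressWeightsMode.INT4_ASYM,
    group_size=-1,
    ignored_scope=nncf.IgnoredScope(names=["__module.model.transformer.h.2.mlp.c_fc/aten::addmm/MatMul"]),
)

# all_layers=True also compresses the layers that are otherwise kept at 8 bit.
full_config = dict(mode=nncf.CompressWeightsMode.INT4_ASYM, group_size=-1, ratio=0.8, all_layers=True)

The expected counts in the tuples (16, 6, 22) track how these options change the number of 4-bit MatMuls in the exported tiny-random-gpt2 graph.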
@@ -249,37 +276,26 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
     @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
-    def test_ovmodel_4bit_weight_compression_stateful(self, model_cls, model_name, expected_int8, expected_int4):
+    def test_ovmodel_8bit_weight_compression_stateful(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             model_id = MODEL_NAMES[model_name]
             transformers_model = model_cls.from_pretrained(model_id, export=True, stateful=True)
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
             if tokenizer.pad_token is None:
                 tokenizer.pad_token = tokenizer.eos_token
 
             quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
-            quantizer.quantize(
-                save_directory=tmp_dir,
-                weights_only=True,
-                quantization_config=OVWeightQuantizationConfig(mode=nncf.CompressWeightsMode.INT4_SYM, ratio=0.8),
-            )
+            quantizer.quantize(save_directory=tmp_dir, weights_only=True)
             model = model_cls.from_pretrained(tmp_dir)
-            self.assertTrue(model.stateful)
-            self.assertTrue(model.use_cache)
 
-            _, num_int8, num_int4 = get_num_quantized_nodes(model)
-            self.assertEqual(expected_int8, num_int8)
-            self.assertEqual(expected_int4, num_int4)
+            _, num_int8, _ = get_num_quantized_nodes(model)
+            self.assertEqual(expected_ov_int8, num_int8)
 
             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
-
             self.assertTrue("logits" in outputs)
-            self.assertTrue("past_key_values" in outputs)
-            self.assertIsInstance(outputs.past_key_values, tuple)
-            self.assertTrue(len(outputs.past_key_values) == 1 and len(outputs.past_key_values[0]) == 0)
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
     def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
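The reworked stateful test now covers the default 8-bit weights-only path instead of the explicit INT4 config it used before. A self-contained sketch of that flow, assuming an optimum-intel install with the OpenVINO backend; the checkpoint name, output directory, and the "text-generation" task string are illustrative rather than taken from the test file:

# Sketch of the weights-only quantization flow exercised by the test above.
from optimum.intel import OVModelForCausalLM, OVQuantizer
from transformers import AutoTokenizer

model = OVModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", export=True, stateful=True)
quantizer = OVQuantizer.from_pretrained(model, task="text-generation")
quantizer.quantize(save_directory="ov_model_int8", weights_only=True)  # default 8-bit weight compression

compressed = OVModelForCausalLM.from_pretrained("ov_model_int8")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
outputs = compressed(**tokenizer("This is a sample input", return_tensors="pt"))
assert "logits" in outputs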
@@ -298,6 +314,21 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
             _, num_int8, _ = get_num_quantized_nodes(model)
             self.assertEqual(expected_ov_int8[i], num_int8)
 
+    @parameterized.expand(LOAD_IN_4_BITS_SCOPE)
+    def test_ovmodel_4bit_auto_compression(self, model_cls, model_id, quantization_config, expected_ov_int4):
+        task = model_cls.export_feature
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model = model_cls.from_pretrained(
+                model_id, export=True, load_in_4bit=True, quantization_config=quantization_config
+            )
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+
+            _, num_int4, _ = get_num_quantized_nodes(model)
+            self.assertEqual(expected_ov_int4, num_int4)
+
     @parameterized.expand(((OVModelForCausalLM, "gpt2"),))
     @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
     def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_type):
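Outside the test harness, the new load_in_4bit path can be exercised directly at from_pretrained time. A minimal sketch mirroring the first LOAD_IN_4_BITS_SCOPE entry (import paths assumed from optimum-intel; nothing here is added to the test file):

# Sketch of load-time 4-bit weight compression, mirroring the new test above.
import nncf
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "hf-internal-testing/tiny-random-gpt2"
model = OVModelForCausalLM.from_pretrained(
    model_id,
    export=True,        # convert the PyTorch checkpoint to OpenVINO IR
    load_in_4bit=True,  # request 4-bit weight compression at load time
    quantization_config=dict(mode=nncf.CompressWeightsMode.INT4_ASYM, group_size=-1, ratio=0.8),
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
outputs = model(**tokenizer("This is a sample input", return_tensors="pt"))
print(outputs.logits.shape)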
