Skip to content

Commit 9943624

Browse files
committed
Style
1 parent 4c821ad commit 9943624

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

tests/openvino/test_quantization.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ class OVWeightCompressionTest(unittest.TestCase):
155155
)
156156

157157
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 365),)
158-
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS = ((OVModelForCausalLM, "hf-internal-testing/tiny-random-OPTForCausalLM", 16, 136),)
158+
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS = (
159+
(OVModelForCausalLM, "hf-internal-testing/tiny-random-OPTForCausalLM", 16, 136),
160+
)
159161
SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = (
160162
(OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 46),
161163
)
@@ -353,18 +355,20 @@ def test_ovmodel_4bit_auto_compression(self, model_cls, model_id, quantization_c
353355

354356
_, num_int4, _ = get_num_quantized_nodes(model)
355357
self.assertEqual(expected_ov_int4, num_int4)
356-
358+
357359
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS)
358-
def test_ovmodel_4bit_auto_compression_with_custom_dataset(self, model_cls, model_id, expected_int8, expected_int4):
360+
def test_ovmodel_4bit_auto_compression_with_custom_dataset(
361+
self, model_cls, model_id, expected_int8, expected_int4
362+
):
359363
task = model_cls.export_feature
360-
364+
361365
tokenizer = AutoTokenizer.from_pretrained(model_id)
362366
if tokenizer.pad_token is None:
363367
tokenizer.pad_token = tokenizer.eos_token
364-
368+
365369
dataset_name, dataset_config_name, column = _TASK_TO_DATASET[task]
366370
dataset = load_dataset(dataset_name, dataset_config_name, split="test")
367-
371+
368372
def transform_fn(data, tokenizer):
369373
tokenized_text = tokenizer(data[column], return_tensors="np")
370374
input_ids = tokenized_text["input_ids"]
@@ -377,7 +381,14 @@ def transform_fn(data, tokenizer):
377381
return inputs
378382

379383
quantization_dataset = nncf.Dataset(dataset, partial(transform_fn, tokenizer=tokenizer))
380-
model = model_cls.from_pretrained(model_id, export=True, load_in_4bit=True, quantization_config=OVWeightQuantizationConfig(mode=nncf.CompressWeightsMode.INT4_SYM, group_size=-1, ratio=0.8, dataset=quantization_dataset))
384+
model = model_cls.from_pretrained(
385+
model_id,
386+
export=True,
387+
load_in_4bit=True,
388+
quantization_config=OVWeightQuantizationConfig(
389+
mode=nncf.CompressWeightsMode.INT4_SYM, group_size=-1, ratio=0.8, dataset=quantization_dataset
390+
),
391+
)
381392

382393
_, num_int8, num_int4 = get_num_quantized_nodes(model)
383394
self.assertEqual(expected_int8, num_int8)

0 commit comments

Comments (0)