Commit 4c821ad

Fixed issues with models larger than 1B. Added tests.

1 parent: c2f373f

File tree: 5 files changed, +41 -9 lines


optimum/exporters/openvino/convert.py (+1)

@@ -95,6 +95,7 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, comp
                 "ratio": compression_ratio,
             },
         }
+
         model = nncf.compress_weights(model, **COMPRESSION_OPTIONS[compression_option])

     compress_to_fp16 = compression_option == "fp16"
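
For context, nncf.compress_weights applies weight-only compression directly to an openvino.runtime.Model, which is what the COMPRESSION_OPTIONS lookup above feeds it. A minimal sketch of the kind of call a 4-bit entry resolves to, assuming an already-exported IR (model path, option name, and values illustrative):

    import nncf
    import openvino.runtime as ov

    core = ov.Core()
    ov_model = core.read_model("model.xml")  # hypothetical path to an exported IR

    # Roughly what an "int4_sym_g128"-style compression option expands to:
    ov_model = nncf.compress_weights(
        ov_model,
        mode=nncf.CompressWeightsMode.INT4_SYM,
        group_size=128,
        ratio=0.8,  # fraction of weights compressed to 4-bit; the rest stay 8-bit
    )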

optimum/intel/openvino/modeling_base.py (+2 -2)

@@ -287,7 +287,7 @@ def _from_transformers(

         compression_option = None
         if load_in_8bit is not None:
-            compression_option = "int8" if load_in_8bit else "fp32"
+            compression_option = "fp32"

         main_export(
             model_name_or_path=model_id,
@@ -304,7 +304,7 @@ def _from_transformers(
         )

         config.save_pretrained(save_dir_path)
-        return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=False, **kwargs)
+        return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=load_in_8bit, **kwargs)

    @classmethod
    def _to_load(
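
Net effect: _from_transformers now always exports an fp32 IR and defers 8-bit weight compression to _from_pretrained, instead of compressing during export. A hedged sketch of the user-facing path this changes, with an illustrative model id:

    from optimum.intel import OVModelForSequenceClassification

    # Export to an OpenVINO IR in fp32, then apply 8-bit weight
    # compression when the exported model is loaded back.
    model = OVModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english",
        export=True,
        load_in_8bit=True,
    )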

optimum/intel/openvino/modeling_decoder.py (+3 -3)

@@ -264,8 +264,8 @@ def _from_transformers(
             task = task + "-with-past"

         compression_option = None
-        if load_in_8bit is not None and not load_in_4bit:
-            compression_option = "int8" if load_in_8bit else "fp32"
+        if load_in_8bit is not None or load_in_4bit is not None:
+            compression_option = "fp32"
         stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
         main_export(
             model_name_or_path=model_id,
@@ -574,7 +574,7 @@ def _from_pretrained(
             local_files_only=local_files_only,
         )

-        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
+        model = cls.load_model(model_cache_path, load_in_8bit=False if load_in_4bit else load_in_8bit)

         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
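
For decoder models, either flag now forces an fp32 export, and the 8-bit load path is skipped whenever 4-bit compression is requested, so the two options no longer conflict. A minimal sketch of the resulting call, with an illustrative model id:

    from optimum.intel import OVModelForCausalLM

    # Exports in fp32, then applies 4-bit weight compression; load_model()
    # receives load_in_8bit=False because load_in_4bit takes precedence.
    model = OVModelForCausalLM.from_pretrained("gpt2", export=True, load_in_4bit=True)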

optimum/intel/openvino/weight_quantization.py (+1 -1)

@@ -141,7 +141,7 @@ def compress_decoder_weights(model, quantization_config: Union[OVWeightQuantizat

        from optimum.gptq.data import get_dataset, prepare_dataset

-       dataset = get_dataset(config.dataset, tokenizer)
+       dataset = get_dataset(config.dataset, tokenizer, seqlen=32)
        dataset = prepare_dataset(dataset)
        dataset = nncf.Dataset(dataset, lambda x: model.prepare_forward_inputs(**x))

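optimum.gptq.data.get_dataset tokenizes calibration samples to a fixed sequence length; its default seqlen is large, so capping it at 32 keeps data-aware compression tractable, which plausibly matters for the 1B+ models this commit targets. A short sketch of the same data-preparation flow, assuming an illustrative tokenizer and the wikitext2 calibration set supported by optimum.gptq:

    from optimum.gptq.data import get_dataset, prepare_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
    # Sample calibration text and tokenize it into short 32-token sequences.
    calibration = get_dataset("wikitext2", tokenizer, seqlen=32)
    calibration = prepare_dataset(calibration)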
tests/openvino/test_quantization.py (+34 -3)

@@ -22,6 +22,7 @@
 import numpy as np
 from datasets import load_dataset
 from parameterized import parameterized
+import openvino.runtime as ov
 import nncf
 from transformers import (
     AutoModelForQuestionAnswering,
@@ -154,7 +155,8 @@ class OVWeightCompressionTest(unittest.TestCase):
     )

     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 365),)
-    SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = (
+    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS = ((OVModelForCausalLM, "hf-internal-testing/tiny-random-OPTForCausalLM", 16, 136),)
+    SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = (
         (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 46),
     )

@@ -170,7 +172,7 @@ class OVWeightCompressionTest(unittest.TestCase):
             "hf-internal-testing/tiny-random-gpt2",
             dict(
                 mode=nncf.CompressWeightsMode.INT4_ASYM,
-                group_size=-1,
+                group_size=32,
                 ignored_scope=nncf.IgnoredScope(names=["__module.model.transformer.h.2.mlp.c_fc/aten::addmm/MatMul"]),
             ),
             6,
@@ -297,7 +299,7 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
         outputs = model(**tokens)
         self.assertTrue("logits" in outputs)

-    @parameterized.expand(SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
     @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
     def test_ovmodel_8bit_weight_compression_stateful(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
@@ -351,6 +353,35 @@ def test_ovmodel_4bit_auto_compression(self, model_cls, model_id, quantization_c

         _, num_int4, _ = get_num_quantized_nodes(model)
         self.assertEqual(expected_ov_int4, num_int4)
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS)
+    def test_ovmodel_4bit_auto_compression_with_custom_dataset(self, model_cls, model_id, expected_int8, expected_int4):
+        task = model_cls.export_feature
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        dataset_name, dataset_config_name, column = _TASK_TO_DATASET[task]
+        dataset = load_dataset(dataset_name, dataset_config_name, split="test")
+
+        def transform_fn(data, tokenizer):
+            tokenized_text = tokenizer(data[column], return_tensors="np")
+            input_ids = tokenized_text["input_ids"]
+            attention_mask = tokenized_text["attention_mask"]
+            inputs = {}
+            inputs["input_ids"] = input_ids
+            inputs["attention_mask"] = attention_mask
+            batch_size = input_ids.shape[0]
+            inputs["beam_idx"] = np.arange(batch_size, dtype=int)
+            return inputs
+
+        quantization_dataset = nncf.Dataset(dataset, partial(transform_fn, tokenizer=tokenizer))
+        model = model_cls.from_pretrained(model_id, export=True, load_in_4bit=True, quantization_config=OVWeightQuantizationConfig(mode=nncf.CompressWeightsMode.INT4_SYM, group_size=-1, ratio=0.8, dataset=quantization_dataset))
+
+        _, num_int8, num_int4 = get_num_quantized_nodes(model)
+        self.assertEqual(expected_int8, num_int8)
+        self.assertEqual(expected_int4, num_int4)

     @parameterized.expand(((OVModelForCausalLM, "gpt2"),))
     @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")

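A note on the new test: the custom transform_fn adds a beam_idx input because stateful OpenVINO decoder models expose beam_idx as a regular model input, so calibration samples must supply it. Outside the test suite, the same pattern passes any custom calibration set into 4-bit compression; a minimal sketch under the same assumptions (model id and dataset illustrative, config imported from the module changed in this commit):

    from functools import partial

    import nncf
    import numpy as np
    from datasets import load_dataset
    from transformers import AutoTokenizer

    from optimum.intel import OVModelForCausalLM
    from optimum.intel.openvino.weight_quantization import OVWeightQuantizationConfig

    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def transform_fn(sample, tokenizer):
        # Tokenize one calibration sample; beam_idx is required by stateful decoders.
        inputs = dict(tokenizer(sample["text"], return_tensors="np"))
        inputs["beam_idx"] = np.arange(inputs["input_ids"].shape[0], dtype=int)
        return inputs

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    calibration = nncf.Dataset(dataset, partial(transform_fn, tokenizer=tokenizer))

    model = OVModelForCausalLM.from_pretrained(
        model_id,
        export=True,
        load_in_4bit=True,
        quantization_config=OVWeightQuantizationConfig(
            mode=nncf.CompressWeightsMode.INT4_SYM,
            group_size=-1,  # no grouping: per-channel quantization
            ratio=0.8,
            dataset=calibration,
        ),
    )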