Skip to content

Commit 1e73450

Browse files
eaidova authored and PenghuiCheng committed
Fix weights compression for OpenVINO models (huggingface#596)
* hot fix for weights compression
* rewrite mock tests
1 parent 7674e33 commit 1e73450

File tree

2 files changed

+50
-22
lines changed

2 files changed

+50
-22
lines changed

optimum/intel/openvino/modeling_decoder.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,10 @@ def _from_transformers(
261261
task = task + "-with-past"
262262

263263
# If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
264-
if load_in_8bit is None or not quantization_config:
265-
ov_config = None
264+
if load_in_8bit is None and not quantization_config:
265+
ov_export_config = None
266266
else:
267-
ov_config = OVConfig(dtype="fp32")
267+
ov_export_config = OVConfig(dtype="fp32")
268268

269269
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
270270

@@ -279,7 +279,7 @@ def _from_transformers(
279279
local_files_only=local_files_only,
280280
force_download=force_download,
281281
trust_remote_code=trust_remote_code,
282-
ov_config=ov_config,
282+
ov_config=ov_export_config,
283283
stateful=stateful,
284284
)
285285

tests/openvino/test_quantization.py

+46-18
Original file line numberDiff line numberDiff line change
@@ -459,36 +459,64 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
459459
self.assertEqual(0, num_int8)
460460

461461
def test_ovmodel_load_large_model_with_default_compressed_weights(self):
462-
with unittest.mock.patch("transformers.modeling_utils.ModuleUtilsMixin") as model_mixin_patch:
463-
model_mixin_patch.num_parameters.return_value = 2e9
462+
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
463+
mock_tensor = unittest.mock.Mock()
464+
mock_tensor.numel = lambda: 2000000000
465+
mock_tensor.requires_grad = True
466+
model_parameters.return_value = [mock_tensor]
464467
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
465468
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
466469
_ = OVModelForCausalLM.from_pretrained(
467470
MODEL_NAMES["llama"], export=True, compile=False, use_cache=False
468471
)
469-
saving_params = {
470-
"model": unittest.mock.ANY,
471-
"path": unittest.mock.ANY,
472-
"compression_option": "int8",
473-
"compression_ratio": None,
474-
}
475-
save_model_patch.aasert_called_with(saving_params)
472+
save_model_patch.assert_called_with(
473+
unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(quantization_config={"bits": 8})
474+
)
476475

477476
def test_ovmodel_load_large_model_with_uncompressed_weights(self):
478-
with unittest.mock.patch("transformers.modeling_utils.ModuleUtilsMixin") as model_mixin_patch:
479-
model_mixin_patch.num_parameters.return_value = 2e9
477+
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
478+
mock_tensor = unittest.mock.Mock()
479+
mock_tensor.numel = lambda: 2000000000
480+
mock_tensor.requires_grad = True
481+
model_parameters.return_value = [mock_tensor]
480482
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
481483
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
482484
_ = OVModelForCausalLM.from_pretrained(
483485
MODEL_NAMES["llama"], export=True, load_in_8bit=False, compile=False, use_cache=False
484486
)
485-
saving_params = {
486-
"model": unittest.mock.ANY,
487-
"path": unittest.mock.ANY,
488-
"compression_option": "fp32",
489-
"compression_ratio": None,
490-
}
491-
save_model_patch.aasert_called_with(saving_params)
487+
save_model_patch.assert_called_with(
488+
unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(dtype="fp32")
489+
)
490+
491+
def test_ovmodel_load_large_model_with_additional_quantization_config(self):
492+
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
493+
mock_tensor = unittest.mock.Mock()
494+
mock_tensor.numel = lambda: 2000000000
495+
mock_tensor.requires_grad = True
496+
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
497+
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
498+
with unittest.mock.patch("nncf.compress_weights") as compress_weights_patch:
499+
_ = OVModelForCausalLM.from_pretrained(
500+
MODEL_NAMES["llama"],
501+
export=True,
502+
compile=False,
503+
use_cache=False,
504+
quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
505+
)
506+
# quantization will be performed later, using load_model
507+
save_model_patch.assert_called_with(
508+
unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(dtype="fp32")
509+
)
510+
compression_params = {
511+
"mode": nncf.CompressWeightsMode.INT4_SYM,
512+
"ratio": 0.8,
513+
"group_size": -1,
514+
"all_layers": None,
515+
"sensitivity_metric": None,
516+
"dataset": None,
517+
"ignored_scope": None,
518+
}
519+
compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params)
492520

493521

494522
class OVQuantizerQATest(unittest.TestCase):

0 commit comments

Comments
 (0)