Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hot fix for weights compression #596

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ def _from_transformers(
task = task + "-with-past"

# If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
if load_in_8bit is None or not quantization_config:
ov_config = None
if load_in_8bit is None and not quantization_config:
ov_export_config = None
else:
ov_config = OVConfig(dtype="fp32")
ov_export_config = OVConfig(dtype="fp32")

stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)

Expand All @@ -279,7 +279,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
ov_config=ov_config,
ov_config=ov_export_config,
stateful=stateful,
)

Expand Down
64 changes: 46 additions & 18 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,36 +459,64 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
self.assertEqual(0, num_int8)

def test_ovmodel_load_large_model_with_default_compressed_weights(self):
with unittest.mock.patch("transformers.modeling_utils.ModuleUtilsMixin") as model_mixin_patch:
model_mixin_patch.num_parameters.return_value = 2e9
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, compile=False, use_cache=False
)
saving_params = {
"model": unittest.mock.ANY,
"path": unittest.mock.ANY,
"compression_option": "int8",
"compression_ratio": None,
}
save_model_patch.aasert_called_with(saving_params)
save_model_patch.assert_called_with(
unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(quantization_config={"bits": 8})
)

def test_ovmodel_load_large_model_with_uncompressed_weights(self):
with unittest.mock.patch("transformers.modeling_utils.ModuleUtilsMixin") as model_mixin_patch:
model_mixin_patch.num_parameters.return_value = 2e9
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, load_in_8bit=False, compile=False, use_cache=False
)
saving_params = {
"model": unittest.mock.ANY,
"path": unittest.mock.ANY,
"compression_option": "fp32",
"compression_ratio": None,
}
save_model_patch.aasert_called_with(saving_params)
save_model_patch.assert_called_with(
unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(dtype="fp32")
)

def test_ovmodel_load_large_model_with_additional_quantization_config(self):
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
with unittest.mock.patch("nncf.compress_weights") as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"],
export=True,
compile=False,
use_cache=False,
quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
)
# quantization will be performed later, using load_model
save_model_patch.assert_called_with(
unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(dtype="fp32")
)
compression_params = {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"ratio": 0.8,
"group_size": -1,
"all_layers": None,
"sensitivity_metric": None,
"dataset": None,
"ignored_scope": None,
}
compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params)


class OVQuantizerQATest(unittest.TestCase):
Expand Down
Loading