
Commit 73adf4a

Address comments
1 parent f61b7e8 commit 73adf4a

File tree: 5 files changed (+46, -49 lines)


optimum/intel/openvino/configuration.py

+28-14
@@ -18,7 +18,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 import torch
 from transformers.utils.quantization_config import QuantizationConfigMixin
@@ -571,9 +571,7 @@ def to_nncf_dict(self) -> Dict[str, Any]:
             mode = "e2m1"
         mode = nncf.CompressWeightsMode(mode)
 
-        awq = None
-        if self.quant_method == "awq" or self.quant_method == OVQuantizationMethod.AWQ:
-            awq = True
+        awq = True if self.quant_method == OVQuantizationMethod.AWQ else None
         sensitivity_metric = nncf.SensitivityMetric(self.sensitivity_metric) if self.sensitivity_metric else None
         backup_mode = nncf.BackupMode(self.backup_precision) if self.backup_precision else None
         result = {
@@ -896,21 +894,22 @@ def __init__(
                 machine arbitrary code present in the model repository.
             **kwargs:
         """
-        if isinstance(weight_quantization_config, dict):
-            weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_config)
-        else:
-            weight_quantization_config = weight_quantization_config.clone()
-        self.weight_quantization_config = weight_quantization_config
+        self.weight_quantization_config = self._initialize_quantization_config(
+            weight_quantization_config, OVWeightQuantizationConfig
+        )
         wqc = self.weight_quantization_config
 
-        if isinstance(full_quantization_config, dict):
-            full_quantization_config = OVQuantizationConfig.from_dict(full_quantization_config)
-        else:
-            full_quantization_config = full_quantization_config.clone()
-        self.full_quantization_config = full_quantization_config
+        self.full_quantization_config = self._initialize_quantization_config(
+            full_quantization_config, OVQuantizationConfig
+        )
         fqc = self.full_quantization_config
 
         if fqc.dtype in ["f8e4m3", "f8e5m2"] and wqc.backup_precision is None:
+            # Here we simulate FP8 backup weight compression precision through full quantization: during weight
+            # compression step some weighted layers are kept in original precision and later are compressed to FP8
+            # during full precision quantization step.
+            # The issue with current approach is that if one provides an ignored scope for the full quantization step,
+            # then the weights of the layers under this ignored scope won't be compressed to FP8.
             # TODO: remove once there is support for FP8 weight compression in NNCF
             wqc.backup_precision = "none"
 
@@ -932,6 +931,21 @@ def __init__(
 
         self.post_init()
 
+    @staticmethod
+    def _initialize_quantization_config(
+        config: Union[dict, OVWeightQuantizationConfig, OVQuantizationConfig],
+        config_type: Type[Union[OVWeightQuantizationConfig, OVQuantizationConfig]],
+    ):
+        if isinstance(config, dict):
+            return config_type.from_dict(config)
+        elif isinstance(config, config_type):
+            return config.clone()
+        else:
+            raise ValueError(
+                f"Unsupported type of quantization config. Expected either a dictionary or an instance of "
+                f"{config_type}, but found: {type(config)}."
+            )
+
     def to_dict(self):
         result = super().to_dict()
         result["weight_quantization_config"] = self.weight_quantization_config.to_dict()
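
The new `_initialize_quantization_config` helper replaces two copies of the same dict-vs-instance branching. A minimal sketch of the behaviour it standardizes, using only calls that appear in this diff (the top-level import path is an assumption based on the public optimum-intel API):

from optimum.intel import OVWeightQuantizationConfig

# A plain dict is materialized via from_dict(); an existing config instance is
# clone()d so the caller's object is never mutated in place; any other type now
# raises ValueError instead of failing later with an AttributeError on .clone().
wq_config = OVWeightQuantizationConfig.from_dict({"bits": 4, "sym": True})
wq_copy = wq_config.clone()
assert wq_copy is not wq_config and wq_copy.bits == wq_config.bits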

optimum/intel/openvino/quantization.py

+17-16
@@ -1014,7 +1014,6 @@ def _weight_only_quantization(
     model: openvino.runtime.Model,
     quantization_config: Union[OVWeightQuantizationConfig, Dict],
     calibration_dataset: Optional[Union[nncf.Dataset, Iterable]] = None,
-    remove_kv_cache_precision_flag: Optional[bool] = True,
     **kwargs,
 ) -> openvino.runtime.Model:
     _verify_not_optimized(model)
@@ -1043,13 +1042,7 @@ def _weight_only_quantization(
         **wc_kwargs,
     )
 
-    if remove_kv_cache_precision_flag:
-        # Remove the KV cache compression disabling flag from the model
-        if compressed_model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"]):
-            prev_rt_info = compressed_model.get_rt_info("runtime_options").value
-            if prev_rt_info["KV_CACHE_PRECISION"] == "f16":
-                prev_rt_info.pop("KV_CACHE_PRECISION")
-                compressed_model.set_rt_info(prev_rt_info, "runtime_options")
+    _remove_f16_kv_cache_precision_flag(compressed_model)
 
     return compressed_model
 
@@ -1065,11 +1058,11 @@ def _full_quantization(
     _verify_not_optimized(model)
     q_kwargs = copy.deepcopy(kwargs)
     q_kwargs.update(quantization_config.to_nncf_dict())
-    return nncf.quantize(
-        model,
-        calibration_dataset=calibration_dataset,
-        **q_kwargs,
-    )
+    quantized_model = nncf.quantize(model, calibration_dataset=calibration_dataset, **q_kwargs)
+
+    _remove_f16_kv_cache_precision_flag(quantized_model)
+
+    return quantized_model
 
 
 def _get_operation_const_op(operation, const_port_id: int):
@@ -1201,9 +1194,7 @@ def merge_ignored_scopes(
     wc_config = quantization_config.weight_quantization_config.clone()
     wc_config.ignored_scope = merge_ignored_scopes(wc_config.ignored_scope, quantization_config.ignored_scope)
     wc_dataset = dataset if wc_config.bits != 8 else None
-    compressed_model = _weight_only_quantization(
-        model, wc_config, wc_dataset, remove_kv_cache_precision_flag=False, **kwargs
-    )
+    compressed_model = _weight_only_quantization(model, wc_config, wc_dataset, **kwargs)
 
     q_config = quantization_config.full_quantization_config.clone()
     q_config.ignored_scope = merge_ignored_scopes(q_config.ignored_scope, quantization_config.ignored_scope)
@@ -1227,3 +1218,13 @@ def _verify_not_optimized(ov_model):
         raise RuntimeError(message_template.format(model_weight_compression_config))
     elif model_quantization_config is not None:
         raise RuntimeError(message_template.format(model_quantization_config))
+
+
+def _remove_f16_kv_cache_precision_flag(model: openvino.Model) -> openvino.Model:
+    # Remove the KV cache compression disabling flag from the model
+    if model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"]):
+        prev_rt_info = model.get_rt_info("runtime_options").value
+        if prev_rt_info["KV_CACHE_PRECISION"] == "f16":
+            prev_rt_info.pop("KV_CACHE_PRECISION")
+            model.set_rt_info(prev_rt_info, "runtime_options")
+    return model
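
Both `_weight_only_quantization` and `_full_quantization` now funnel through the shared `_remove_f16_kv_cache_precision_flag` helper instead of the old keyword flag. A rough, self-contained illustration of the rt_info round-trip it performs (the toy model below is a placeholder; only the runtime-option handling mirrors the code above):

import openvino as ov
from openvino.runtime import opset13 as ops

# Build a tiny placeholder model just to have an ov.Model to annotate.
param = ops.parameter([1, 8], dtype=ov.Type.f32, name="x")
model = ov.Model([ops.relu(param)], [param], "demo")

# Simulate the flag that weight compression sets to disable KV cache compression.
model.set_rt_info("f16", ["runtime_options", "KV_CACHE_PRECISION"])

# The helper removes the flag only when it is exactly "f16", restoring the
# default KV cache behaviour for the quantized model.
if model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"]):
    rt_options = model.get_rt_info("runtime_options").value
    if rt_options["KV_CACHE_PRECISION"] == "f16":
        rt_options.pop("KV_CACHE_PRECISION")
        model.set_rt_info(rt_options, "runtime_options")

assert not model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"])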

tests/openvino/test_exporters_cli.py

-4
@@ -507,14 +507,11 @@ def test_exporters_cli_full_quantization(
             submodels = [model.encoder, model.decoder]
             if model.decoder_with_past is not None:
                 submodels.append(model.decoder_with_past)
-                expected_kv_cache_precision_per_model = [None, None, None]
             else:
                 expected_num_weight_nodes_per_model = expected_num_weight_nodes_per_model[:-1]
                 expected_fake_nodes_per_model = expected_fake_nodes_per_model[:-1]
-                expected_kv_cache_precision_per_model = [None, "f16"]
         elif "text-generation" in task:
             submodels = [model]
-            expected_kv_cache_precision_per_model = ["f16"]
         else:
             raise Exception("Unexpected task.")
 
@@ -523,7 +520,6 @@
             submodels,
             expected_num_weight_nodes_per_model,
             expected_fake_nodes_per_model,
-            expected_kv_cache_precision_per_model,
         )
 
     def test_exporters_cli_int4_with_local_model_and_default_config(self):

tests/openvino/test_quantization.py

-4
@@ -342,17 +342,14 @@ def test_ov_model_static_quantization_with_auto_dataset(
             submodels = [ov_model.encoder.model, ov_model.decoder.model]
             if ov_model.decoder_with_past is not None:
                 submodels.append(ov_model.decoder_with_past.model)
-                expected_kv_cache_precision_per_model = [None, None, None]
             else:
                 expected_num_weight_nodes_per_model = expected_num_weight_nodes_per_model[:-1]
                 expected_fake_nodes_per_model = expected_fake_nodes_per_model[:-1]
-                expected_kv_cache_precision_per_model = [None, "f16"]
 
             input_features = torch.randn((1, 128, 3000), dtype=torch.float32)
             ov_model.generate(input_features)
         elif model_cls == OVModelForCausalLM:
             submodels = [ov_model]
-            expected_kv_cache_precision_per_model = ["f16"]
 
             tokenizer = AutoTokenizer.from_pretrained(model_id)
             if tokenizer.pad_token is None:
@@ -368,7 +365,6 @@
             submodels,
             expected_num_weight_nodes_per_model,
             expected_fake_nodes_per_model,
-            expected_kv_cache_precision_per_model,
         )
 
tests/openvino/utils_tests.py

+1-11
@@ -295,12 +295,10 @@ def check_compression_state_per_model(
     models: List[Union[ov.Model, OVBaseModel]],
     expected_num_weight_nodes_per_model: List[Dict[str, int]],
     expected_num_fake_nodes_per_model: Optional[List[int]] = None,
-    expected_kv_cache_precision_per_model: Optional[List[Union[str, None]]] = None,
 ):
     test_case.assertEqual(len(models), len(expected_num_weight_nodes_per_model))
     actual_num_weights_per_model = [{}] * len(models)
     actual_num_fake_nodes_per_model = [0] * len(models)
-    actual_kv_cache_precision_per_model = [None] * len(models)
     for i, (submodel, expected_num_weight_nodes) in enumerate(zip(models, expected_num_weight_nodes_per_model)):
         ov_model = submodel if isinstance(submodel, ov.Model) else submodel.model
         num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(ov_model)
@@ -309,19 +307,11 @@
         actual_num_weights_per_model[i] = num_weight_nodes
         actual_num_fake_nodes_per_model[i] = num_fake_nodes
 
-        if ov_model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"]):
-            actual_kv_cache_precision = ov_model.get_rt_info(["runtime_options", "KV_CACHE_PRECISION"]).value
-        else:
-            actual_kv_cache_precision = None
-        actual_kv_cache_precision_per_model[i] = actual_kv_cache_precision
+        test_case.assertFalse(ov_model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"]))
 
     # Check weight nodes
     test_case.assertEqual(expected_num_weight_nodes_per_model, actual_num_weights_per_model)
 
     # Check fake nodes
     if expected_num_fake_nodes_per_model is not None:
         test_case.assertEqual(expected_num_fake_nodes_per_model, actual_num_fake_nodes_per_model)
-
-    # Check KV cache precision
-    expected_kv_cache_precision_per_model = expected_kv_cache_precision_per_model or ([None] * len(models))
-    test_case.assertEqual(expected_kv_cache_precision_per_model, actual_kv_cache_precision_per_model)
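
Since both quantization entry points now always strip the flag, the per-model expectation list is unnecessary: the helper simply asserts the flag is absent on every submodel. Call sites shrink accordingly; a sketch of the updated invocation inside a test (the opening line of the call and the test_case argument are inferred, the surrounding test body is abridged as in the hunks above):

check_compression_state_per_model(
    self,
    submodels,
    expected_num_weight_nodes_per_model,
    expected_fake_nodes_per_model,
)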
