Commit f3f232f
Computation of compression parameters via OpenVINO models (#2727)
### Changes

- Implemented OpenVINO model graphs which are used for calculation of compressed and decompressed weights. Since these models are compiled, calculations become significantly faster, especially for larger models and int4 compression.
- This functionality is exposed by two methods in `weight_lowering.py` (a reference sketch of the signature conventions is given at the end of this description):
  - `do_int_quantization()` is used for computing a compressed weight. Possible signatures:
    - `weight` -> `compressed_weight`, `scale`, (`zero_point` for asymmetric compression)
    - `weight`, `scale`, (`zero_point`) -> `compressed_weight`, `scale`, (`zero_point`)
  - `calculate_quantized_dequantized_weight()` is used for computing a decompressed weight. Possible signatures:
    - `weight` -> `decompressed_weight`
    - `weight`, `scale`, (`zero_point`) -> `decompressed_weight`
    - `weight` -> `decompressed_weight`, `compressed_weight`, `scale`, (`zero_point`)
    - `weight`, `scale`, (`zero_point`) -> `decompressed_weight`, `compressed_weight`, `scale`, (`zero_point`)
  - Output `scale` and `zero_point` are the same as the ones given as input (if they were given at all).
  - Computation is done via OV models only if the openvino package is installed and the input tensors are not torch tensors.
- Introduced a new NNCF Tensor backend for storing instances of `openvino.Tensor`. The implementation for this backend is limited to only the required functionality; e.g. addition of OV Tensors is not supported because it is not needed.
  - Introduction of OV Tensors is required for seamless handling of tensors in `bf16`, `u4` and `i4` data types. For example, `bf16` constants are read from an OpenVINO LLM and given as inputs to a compressing OpenVINO model, and `u4`/`i4` compressed weights are seamlessly inserted into the resulting compressed OpenVINO model.
  - Added an `as_numpy_tensor()` method to convert an NNCF Tensor to the numpy backend. Currently only the OV -> NP conversion is required.
- All calculations are aligned with the reference numpy implementation. Some performance and memory sacrifices had to be made for such alignment.

Data-free asymmetric compression:
![image](https://github.com/user-attachments/assets/efd76b2f-1a3e-4037-8165-0bd5812de94d)

Data-free symmetric compression:
![image](https://github.com/user-attachments/assets/c61b98c6-cc96-4125-b21e-90c7d0827e22)

Data-aware compression:
![image](https://github.com/user-attachments/assets/b9823594-9915-4ca5-9e50-7bffa6777104)

### Reason for changes

Reducing model compression time. Only the OpenVINO model compression backend is affected.

### Related tickets

139047

### Tests

- `tests/openvino/native/quantization/test_ov_modeling_compression.py::test_quantization_alignment` -- checks alignment with the reference numpy implementation
- `tests/openvino/native/test_openvino_modeling.py` -- checks OV modeling framework hyperparameters
- `tests/openvino/native/test_tensor.py` -- NNCF OV Tensor backend tests

Validation jobs:
- `NNCF/job/manual/job/post_training_weight_compression/299/`
- `NNCF/job/nightly/job/test_examples/650`
- OVVP validation ✅
- optimum-intel test job: https://github.com/huggingface/optimum-intel/actions/runs/12912964434/job/36009036879?pr=734
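Reference sketch of the signature conventions above (not part of the diff): a minimal numpy reimplementation of the asymmetric case. The `*_ref` names, the per-row reduction axis and the 4-bit default are illustrative assumptions, not the actual `weight_lowering.py` code.

```python
import numpy as np


def do_int_quantization_ref(weight: np.ndarray, num_bits: int = 4):
    """weight -> compressed_weight, scale, zero_point (asymmetric case)."""
    level_low, level_high = 0, 2**num_bits - 1  # [0, 15] for uint4
    # Include zero in the range so that zero is exactly representable.
    w_min = np.minimum(weight.min(axis=-1, keepdims=True), 0.0)
    w_max = np.maximum(weight.max(axis=-1, keepdims=True), 0.0)
    scale = (w_max - w_min) / (level_high - level_low)
    scale = np.where(scale == 0.0, 1.0, scale)  # guard against constant rows
    zero_point = np.clip(np.round(-w_min / scale), level_low, level_high)
    compressed = np.clip(np.round(weight / scale + zero_point), level_low, level_high)
    return compressed.astype(np.uint8), scale, zero_point.astype(np.uint8)


def calculate_quantized_dequantized_weight_ref(weight, scale=None, zero_point=None):
    """weight[, scale, zero_point] -> decompressed_weight. Precomputed scale and
    zero_point are reused as-is, mirroring the signatures listed above."""
    if scale is None:
        _, scale, zero_point = do_int_quantization_ref(weight)
    compressed = np.clip(np.round(weight / scale + zero_point), 0, 2**4 - 1)
    return (compressed - zero_point.astype(np.float32)) * scale


weight = np.random.rand(8, 16).astype(np.float32)
decompressed = calculate_quantized_dequantized_weight_ref(weight)
assert decompressed.shape == weight.shape
```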
1 parent b6f2e75 commit f3f232f

32 files changed: +2195 -293 lines changed

docs/api/source/conf.py
+1

@@ -145,6 +145,7 @@ def collect_api_entities() -> APIInfo:
         "nncf.tensor.functions.torch_linalg",
         "nncf.tensor.functions.torch_io",
         "nncf.tensor.functions.numpy_io",
+        "nncf.tensor.functions.ov_numeric",
     ]
 
     with mock(mock_modules):

nncf/common/logging/logger.py
+32 -12

@@ -11,11 +11,41 @@
 
 import logging
 import sys
-from typing import Set
+from functools import lru_cache
+from typing import cast
+
+
+class NNCFLogger(logging.Logger):
+    def __init__(self, name: str, level: int = logging.NOTSET):
+        super().__init__(name, level)
+
+    @lru_cache(None)
+    def _log_once(self, level: int, msg: str) -> None:
+        self.log(level, msg)
+
+    def debug_once(self, msg: str) -> None:
+        """
+        Log a message at the DEBUG level, ensuring the message is logged only once.
+        """
+        self._log_once(logging.DEBUG, msg)
+
+    def info_once(self, msg: str) -> None:
+        """
+        Log a message at the INFO level, ensuring the message is logged only once.
+        """
+        self._log_once(logging.INFO, msg)
+
+    def warning_once(self, msg: str) -> None:
+        """
+        Log a message at the WARNING level, ensuring the message is logged only once.
+        """
+        self._log_once(logging.WARNING, msg)
+
 
 NNCF_LOGGER_NAME = "nncf"
 
-nncf_logger = logging.getLogger(NNCF_LOGGER_NAME)
+logging.setLoggerClass(NNCFLogger)
+nncf_logger = cast(NNCFLogger, logging.getLogger(NNCF_LOGGER_NAME))
 nncf_logger.propagate = False
 
 stdout_handler = logging.StreamHandler(sys.stdout)
@@ -60,16 +90,6 @@ def disable_logging() -> None:
     nncf_logger.handlers = []
 
 
-class DuplicateFilter:
-    def __init__(self) -> None:
-        self.msgs: Set[str] = set()
-
-    def filter(self, rec: logging.LogRecord) -> bool:
-        retval = rec.msg not in self.msgs
-        self.msgs.add(rec.msg)
-        return retval
-
-
 NNCFDeprecationWarning = FutureWarning
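Usage sketch for the new log-once API (not part of the diff; assumes NNCF at this commit):

```python
import logging

from nncf.common.logging.logger import nncf_logger

nncf_logger.setLevel(logging.DEBUG)

for step in range(3):
    # Identical (level, message) pairs are deduplicated for the process
    # lifetime via the lru_cache on NNCFLogger._log_once, so this prints once.
    nncf_logger.warning_once("bf16 constant encountered, falling back to fp32")
    # Distinct messages are each logged once, so this prints on every iteration.
    nncf_logger.info_once(f"processing step {step}")
```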

nncf/common/utils/backend.py
+15

@@ -16,6 +16,13 @@
 
 TModel = TypeVar("TModel")
 
+try:
+    import openvino  # type: ignore # noqa: F401
+
+    _OPENVINO_AVAILABLE = True
+except ImportError:
+    _OPENVINO_AVAILABLE = False
+
 
 class BackendType(Enum):
     TORCH = "Torch"
@@ -159,3 +166,11 @@ def copy_model(model: TModel) -> TModel:
         model = TFModelTransformer(model).transform(TFTransformationLayout())
         return model
     return deepcopy(model)
+
+
+def is_openvino_available() -> bool:
+    """
+    Check if OpenVINO is available.
+    :return: True if openvino package is installed, False otherwise.
+    """
+    return _OPENVINO_AVAILABLE
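The new helper lets call sites choose the optimized path without importing openvino themselves. A minimal sketch (not part of the diff):

```python
from nncf.common.utils.backend import is_openvino_available

# Dispatch rule from the commit message: compute via compiled OV models only
# when the openvino package is importable; otherwise fall back to the
# reference numpy implementation.
if is_openvino_available():
    print("weight compression will run via compiled OpenVINO models")
else:
    print("weight compression will run via the reference numpy implementation")
```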

nncf/common/utils/caching.py
+103 (new file)

@@ -0,0 +1,103 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import inspect
+from contextlib import contextmanager
+from functools import wraps
+from typing import Any, Callable, Dict, Iterator, TypeVar, cast
+
+
+class ResultsCache:
+    """
+    A container for results decorated with @cache_results decorator.
+    """
+
+    def __init__(self) -> None:
+        self._enabled = True
+        # Stores the results of the decorated function
+        self._cache: Dict[Any, Any] = {}
+        # Stores the number of times the cached result was accessed
+        self._access_count: Dict[Any, int] = {}
+
+    def enable(self) -> None:
+        self._enabled = True
+
+    def disable(self) -> None:
+        self._enabled = False
+
+    def enabled(self) -> bool:
+        return self._enabled
+
+    def access_count(self) -> Dict[Any, int]:
+        return copy.deepcopy(self._access_count)
+
+    def clear(self) -> None:
+        self._cache.clear()
+        self._access_count.clear()
+
+    def __getitem__(self, key: Any) -> Any:
+        self._access_count[key] += 1
+        return self._cache[key]
+
+    def __setitem__(self, key: Any, value: Any) -> None:
+        self._access_count[key] = 0
+        self._cache[key] = value
+
+    def __contains__(self, key: Any) -> bool:
+        return key in self._cache
+
+
+TFunc = TypeVar("TFunc", bound=Callable[..., Any])
+
+
+def cache_results(cache: ResultsCache) -> Callable[[TFunc], TFunc]:
+    """
+    Decorator to cache results of a function. When decorated function is called with the same set of arguments, it
+    will return the cached result instead of recomputing it. If it was the first call with such set of arguments, the
+    result will be computed and stored in the cache. The cache is stored in the `cache` object. Function arguments
+    must be hashable.
+
+    :param cache: A cache container where results will be stored.
+    """
+
+    def decorator(func: TFunc) -> TFunc:
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if not cache.enabled():
+                return func(*args, **kwargs)
+            sig = inspect.signature(func)
+            new_kwargs = {name: arg for name, arg in zip(sig.parameters, args)}
+            new_kwargs.update(kwargs)
+            cache_key = (func.__name__, frozenset(new_kwargs.items()))
+            if cache_key in cache:
+                return cache[cache_key]
+            result = func(*args, **kwargs)
+            cache[cache_key] = result
+            return result
+
+        return cast(TFunc, wrapper)
+
+    return decorator
+
+
+@contextmanager
+def disable_results_caching(cache: ResultsCache) -> Iterator[None]:
+    """
+    Context manager to disable caching of results for a block of code.
+
+    :param cache: A cache container where results are stored.
+    """
+    if cache.enabled():
+        cache.disable()
+        yield
+        cache.enable()
+    else:
+        yield
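Usage sketch for the caching utilities (not part of the diff); `build_model` is a hypothetical stand-in for an expensive compiled-model factory:

```python
from nncf.common.utils.caching import ResultsCache, cache_results, disable_results_caching

_model_cache = ResultsCache()


@cache_results(_model_cache)
def build_model(shape: tuple, mode: str) -> object:
    # Stand-in for compiling an OV helper model; arguments must be hashable.
    return object()


m1 = build_model((64, 128), "sym")  # computed and stored under a key derived
m2 = build_model((64, 128), "sym")  # from the function name and its arguments
assert m2 is m1

with disable_results_caching(_model_cache):
    m3 = build_model((64, 128), "sym")  # recomputed, cache bypassed
assert m3 is not m1

_model_cache.clear()  # drop stored results and access counters
```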

nncf/openvino/graph/node_utils.py
+50 -7

@@ -13,6 +13,7 @@
 
 import numpy as np
 import openvino.runtime as ov
+import openvino.runtime.op as op
 import openvino.runtime.opset13 as opset
 
 import nncf
@@ -41,6 +42,8 @@
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVOpMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
+from nncf.tensor import Tensor
+from nncf.tensor import TensorBackend
 
 InplaceInsertionFnType = Callable[[ov.Node, int, str], ov.Node]
 
@@ -97,26 +100,27 @@ def get_number_if_op(model: ov.Model) -> int:
     """
 
     def cnt_if_op(model: ov.Model, cnt: int) -> int:
-        for op in model.get_ops():
-            if get_node_metatype(op) == OVIfMetatype:
+        for model_op in model.get_ops():
+            if get_node_metatype(model_op) == OVIfMetatype:
                 cnt += 1
-                cnt = cnt_if_op(op.get_function(0), cnt)
-                cnt = cnt_if_op(op.get_function(1), cnt)
+                cnt = cnt_if_op(model_op.get_function(0), cnt)
+                cnt = cnt_if_op(model_op.get_function(1), cnt)
         return cnt
 
     return cnt_if_op(model, 0)
 
 
-def get_const_value(const_node: ov.Node) -> np.ndarray:
+def get_const_value(const_node: ov.Node, cast_bf16_to_fp32: bool = True) -> np.ndarray:
     """
     Returns the constant tensor for the node.
     This method is applicable only for the floating-point constant data.
 
     :param const_node: OpenVINO node.
+    :param cast_bf16_to_fp32: Whether to cast bf16 node data to fp32 or not. If False and the node contains bf16 data,
+        the resulting bf16 value will be returned encoded inside a numpy.float16 array.
     :return: The constant value.
     """
-    if const_node.get_element_type() == ov.Type.bf16:
-        # Fixed FP32 data type as the result for BF16 constant
+    if const_node.get_element_type() == ov.Type.bf16 and cast_bf16_to_fp32:
        return const_node.get_data(dtype=np.float32)
     return const_node.data
 
@@ -635,3 +639,42 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple
         channel_axis = activations_layout.index(OVLayoutElem.C_IN)
 
     return channel_axis
+
+
+def convert_op(node: ov.Node, target_dtype: ov.Type) -> ov.Node:
+    """
+    Return a subgraph which converts the given node output to the target data type. If the output is already in the
+    target data type then the given node is returned.
+
+    :param node: The input node to convert.
+    :param target_dtype: The target data type to convert the input node to.
+    :return: The converted node.
+    """
+    if node.get_element_type() == target_dtype:
+        return node
+    return opset.convert(node, target_dtype)
+
+
+def non_convertable_divide_op(a: ov.Node, b: ov.Node) -> ov.Node:
+    """
+    Creates a "non-convertable" divide operation. It won't be converted to a*(1/b).
+    """
+    divide_node = a / b
+    divide_node.get_rt_info()["nonconvertable_divide_0"] = True
+    return divide_node
+
+
+def create_ov_const_from_tensor(x: Tensor, dtype: ov.Type, name: Optional[str] = None) -> op.Constant:
+    """
+    Create an OpenVINO Constant node from the given tensor.
+    :param x: Data tensor. Supports NumPy and OV tensor backends. If x backend is OV, the constant node is created
+        directly from underlying OV tensor.
+    :param dtype: Data type of the constant.
+    :param name: Optional name of the constant.
+    :return: OpenVINO Constant node.
+    """
+    if x.backend == TensorBackend.ov:
+        assert x.data.get_element_type() == dtype
+        return opset.constant(x.data, name=name, shared_memory=True)
+    const = opset.constant(x.data, dtype=dtype, name=name)
+    return const
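A small sketch composing the new helpers into a subgraph (not part of the diff; the shapes and the fp16 scale dtype are illustrative assumptions):

```python
import numpy as np
import openvino.runtime as ov

from nncf.openvino.graph.node_utils import convert_op
from nncf.openvino.graph.node_utils import create_ov_const_from_tensor
from nncf.openvino.graph.node_utils import non_convertable_divide_op
from nncf.tensor import Tensor

weight = create_ov_const_from_tensor(Tensor(np.ones((4, 8), np.float32)), ov.Type.f32, name="weight")
scale = create_ov_const_from_tensor(Tensor(np.full((4, 1), 0.1, np.float16)), ov.Type.f16, name="scale")

w = convert_op(weight, ov.Type.f32)  # already f32: the node is returned as-is
s = convert_op(scale, ov.Type.f32)   # inserts a Convert node
# Division that the runtime must not rewrite as w * (1 / s), which helps keep
# the results aligned with the reference numpy implementation:
quantized = non_convertable_divide_op(w, s)
```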
nncf/openvino/optimized_functions/__init__.py
+16 (new file)

@@ -0,0 +1,16 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nncf.openvino.optimized_functions.functions import astype as astype
+from nncf.openvino.optimized_functions.functions import do_int_quantization as do_int_quantization
+from nncf.openvino.optimized_functions.functions import quantize_dequantize_weight as quantize_dequantize_weight
+from nncf.openvino.optimized_functions.models import OVModelParameters as OVModelParameters
+from nncf.openvino.optimized_functions.models import clear_ov_model_cache as clear_ov_model_cache
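Hypothetical usage of the re-exported entry points (not part of the diff). The exact signatures live in `functions.py`/`models.py`, which are not shown here, and a no-argument `clear_ov_model_cache()` is an assumption:

```python
from nncf.openvino.optimized_functions import clear_ov_model_cache
from nncf.openvino.optimized_functions import do_int_quantization  # noqa: F401
from nncf.openvino.optimized_functions import quantize_dequantize_weight  # noqa: F401

# ... run weight compression; the compiled helper models are cached between
# calls (see nncf/common/utils/caching.py above) ...

# Release the cached compiled models once compression is finished:
clear_ov_model_cache()
```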
