
Commit 1762c5c

Refactoring of minmax algorithm parameter setup and checking (#3309)
### Changes

- Fixed the `_override_device` function, which used `==` (comparison) where `=` (assignment) was intended and therefore never actually switched the target device from NPU to CPU:
  https://github.com/openvinotoolkit/nncf/blob/5c75e22c2888ebde2a87534c8cb204497899b0b7/nncf/quantization/algorithms/min_max/algorithm.py#L253
- Apply the overflow fix only for 8-bit quantization by default:
  https://github.com/openvinotoolkit/nncf/blob/5c75e22c2888ebde2a87534c8cb204497899b0b7/nncf/quantization/advanced_parameters.py#L36

### Reason for changes

The overflow fix should only be applied for 8-bit quantization.

### Related tickets

None

### Tests

- `test_target_device`
- `test_npu_target_device`
- `test_overflow_fix`
1 parent c40f73e commit 1762c5c
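To make the first fix concrete: with an 8-bit activation/weight (A8W8) scheme, an NPU target is now remapped to the CPU quantization scheme, while sub-8-bit schemes keep the NPU device. A minimal sketch of that behavior, mirroring the new `test_npu_target_device` test added below (import paths are assumed to match those used elsewhere in the repository, and the check pokes the private `_target_device` attribute exactly as the tests do):

```python
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization

# A8W8 on NPU: the target is remapped to the CPU quantization scheme.
# Before this fix, `==` (comparison) was used where `=` (assignment) was
# intended, so the remap silently never happened.
algo = MinMaxQuantization(
    target_device=TargetDevice.NPU,
    activations_quantization_params=QuantizationParameters(num_bits=8),
    weights_quantization_params=QuantizationParameters(num_bits=8),
)
assert algo._target_device == TargetDevice.CPU

# A sub-8-bit scheme keeps the NPU target device as-is.
algo = MinMaxQuantization(
    target_device=TargetDevice.NPU,
    activations_quantization_params=QuantizationParameters(num_bits=4),
    weights_quantization_params=QuantizationParameters(num_bits=4),
)
assert algo._target_device == TargetDevice.NPU
```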

6 files changed: +80, -86


#### `nncf/quantization/algorithms/min_max/algorithm.py` (+46, -53)
```diff
@@ -208,16 +208,12 @@ def __init__(
         self._ignored_scope = IgnoredScope() if ignored_scope is None else ignored_scope
         self.quantizer_propagation_rule = quantizer_propagation_rule
 
-        # preset definition
-        if self._preset is None:
-            if model_type == ModelType.TRANSFORMER:
-                self._preset = QuantizationPreset.MIXED
-            else:
-                self._preset = QuantizationPreset.PERFORMANCE
+        # validate input parameter types
+        self._validate_param_types()
 
-        self._override_device()
-        self._set_mode_based_defaults()
-        self._review_mode_based_defaults()
+        # set and validate mode based parameters
+        self._set_mode_based_params()
+        self._review_mode_based_params()
 
         self._quantization_params = {
             QuantizerGroup.WEIGHTS: self._weights_quantization_params,
@@ -238,35 +234,64 @@ def __init__(
         self._reset_cache()
         self._algorithm_key = f"MMQ_{hash(self)}"
 
-    def _override_device(self) -> None:
+    def _validate_param_types(self) -> None:
+        """
+        Validates the types of the provided quantization parameters.
+
+        Raises:
+            nncf.ParameterNotSupportedError: If the parameter types do not match the expected quantization mode.
+        """
+        expected_cls = QuantizationParameters
+        if self._mode in (QuantizationMode.FP8_E4M3, QuantizationMode.FP8_E5M2):
+            expected_cls = FP8QuantizationParameters
+
+        for param, name in [
+            (self._weights_quantization_params, "weights"),
+            (self._activations_quantization_params, "activations"),
+        ]:
+            if param and not isinstance(param, expected_cls):
+                msg = f"Quantization parameters for {name} ({param}) are not supported with the selected mode!"
+                raise nncf.ParameterNotSupportedError(msg)
+
+    def _set_mode_based_params(self) -> None:
         """
-        Overrides NPU device to use CPU quantization scheme.
+        Sets parameters for the algorithms based on the provided mode.
         """
-        if self._target_device == TargetDevice.NPU:
-            act_bits, weight_bits = 8, 8
+        if self._mode is None:
+            if self._preset is None:
+                if self._model_type == ModelType.TRANSFORMER:
+                    self._preset = QuantizationPreset.MIXED
+                else:
+                    self._preset = QuantizationPreset.PERFORMANCE
+
+            act_bits = DEFAULT_QCONFIG.num_bits
+            weight_bits = DEFAULT_QCONFIG.num_bits
             if self._activations_quantization_params and self._activations_quantization_params.num_bits:
                 act_bits = self._activations_quantization_params.num_bits
             if self._weights_quantization_params and self._weights_quantization_params.num_bits:
                 weight_bits = self._weights_quantization_params.num_bits
 
-            if act_bits == 8 and weight_bits == 8:
-                self._target_device == TargetDevice.CPU
+            quant_scheme_a8w8 = act_bits == 8 and weight_bits == 8
+            if self._target_device == TargetDevice.NPU and quant_scheme_a8w8:
+                self._target_device = TargetDevice.CPU
                 nncf_logger.debug("Target device NPU was changed to CPU!")
 
-    def _set_mode_based_defaults(self) -> None:
-        """
-        Sets defaults for the algorithms based on the provided mode.
-        """
+            if self._overflow_fix is None and not quant_scheme_a8w8:
+                self._overflow_fix = OverflowFix.DISABLE
+                nncf_logger.debug("Overflow fix was disabled because quantization scheme is not A8W8.")
+        elif self._preset is None:
+            self._preset = QuantizationPreset.PERFORMANCE
+
         mode_based_defaults = MODE_BASED_DEFAULTS[self._mode]
         for field in dataclasses.fields(mode_based_defaults):
             self_name = "_" + field.name
             default_value = getattr(mode_based_defaults, field.name)
             if getattr(self, self_name) is None:
                 setattr(self, self_name, default_value)
 
-    def _review_mode_based_defaults(self):
+    def _review_mode_based_params(self):
         """
-        Reviews default values because mode option doesn't support them.
+        Reviews parameter values because mode option doesn't support them.
         """
         if self._mode in (QuantizationMode.FP8_E4M3, QuantizationMode.FP8_E5M2):
             nncf_logger.warning(f"You're using experimental option mode with {self._mode} value.")
@@ -287,38 +312,6 @@ def _review_mode_based_defaults(self):
                 msg = "quantize_outputs option is not supported with the mode option!"
                 raise nncf.ParameterNotSupportedError(msg)
 
-            if isinstance(self._weights_quantization_params, QuantizationParameters):
-                msg = (
-                    "quantization_params option for weights with "
-                    f"{self._weights_quantization_params} "
-                    "value is not supported with the mode option!"
-                )
-                raise nncf.ParameterNotSupportedError(msg)
-
-            if isinstance(self._activations_quantization_params, QuantizationParameters):
-                msg = (
-                    "quantization_params option for activations with "
-                    f"{self._activations_quantization_params} "
-                    "value is not supported with the mode option!"
-                )
-                raise nncf.ParameterNotSupportedError(msg)
-        elif self._mode is None:
-            if isinstance(self._weights_quantization_params, FP8QuantizationParameters):
-                msg = (
-                    "quantization_params option for weights with "
-                    f"{self._weights_quantization_params} "
-                    "value is not supported with the mode: None option!"
-                )
-                raise nncf.ParameterNotSupportedError(msg)
-
-            if isinstance(self._activations_quantization_params, FP8QuantizationParameters):
-                msg = (
-                    "quantization_params option for activations with "
-                    f"{self._activations_quantization_params} "
-                    "value is not supported with the mode: None option!"
-                )
-                raise nncf.ParameterNotSupportedError(msg)
-
     def _reset_cache(self) -> None:
         """
         Marks cache by noninitialized values. Needs to be called when the new quantizer setup is needed.
```
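The net effect of the rewritten setup above: when `overflow_fix` is not set explicitly, it keeps its usual default only for the A8W8 scheme and falls back to `OverflowFix.DISABLE` for any other bit width. A minimal sketch of that behavior, mirroring the new `test_overflow_fix` below (import paths are assumed to match those used elsewhere in the repository):

```python
from nncf.quantization.advanced_parameters import OverflowFix, QuantizationParameters
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization

# A8W8: the default overflow fix (FIRST_LAYER) is applied.
algo_a8w8 = MinMaxQuantization(
    activations_quantization_params=QuantizationParameters(num_bits=8),
    weights_quantization_params=QuantizationParameters(num_bits=8),
)
assert algo_a8w8._overflow_fix == OverflowFix.FIRST_LAYER

# Any other scheme (here A8W4): the overflow fix is disabled by default.
algo_a8w4 = MinMaxQuantization(
    activations_quantization_params=QuantizationParameters(num_bits=8),
    weights_quantization_params=QuantizationParameters(num_bits=4),
)
assert algo_a8w4._overflow_fix == OverflowFix.DISABLE
```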

#### `tests/common/quantization/test_minmax.py` (+34)
```diff
@@ -245,3 +245,37 @@ def fill_qsetup_mock(self, *args):
         for _ in range(run_nums):
             algo._get_quantization_target_points(None, None)
         assert find_called == fill_called == 2
+
+
+@pytest.mark.parametrize(
+    "target_device", [target_device for target_device in TargetDevice if target_device != TargetDevice.NPU]
+)
+def test_target_device(target_device):
+    min_max_algo = MinMaxQuantization(target_device=target_device)
+    assert min_max_algo._target_device == target_device
+
+
+@pytest.mark.parametrize("num_bits, ref_hw_target_device", zip([8, 4], [TargetDevice.CPU, TargetDevice.NPU]))
+def test_npu_target_device(num_bits, ref_hw_target_device):
+    min_max_algo = MinMaxQuantization(
+        target_device=TargetDevice.NPU,
+        activations_quantization_params=QuantizationParameters(num_bits=num_bits),
+        weights_quantization_params=QuantizationParameters(num_bits=num_bits),
+    )
+    assert min_max_algo._target_device == ref_hw_target_device
+
+
+@pytest.mark.parametrize("activation_bits", [8, 4])
+@pytest.mark.parametrize("weight_bits", [8, 4])
+def test_overflow_fix(activation_bits, weight_bits):
+    quant_scheme_a8w8 = activation_bits == 8 and weight_bits == 8
+
+    min_max_algo = MinMaxQuantization(
+        activations_quantization_params=QuantizationParameters(num_bits=activation_bits),
+        weights_quantization_params=QuantizationParameters(num_bits=weight_bits),
+    )
+
+    if quant_scheme_a8w8:
+        assert min_max_algo._overflow_fix == OverflowFix.FIRST_LAYER
+    else:
+        assert min_max_algo._overflow_fix == OverflowFix.DISABLE
```
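For verification, the three new tests can be run directly with pytest, for example `pytest tests/common/quantization/test_minmax.py -k "target_device or overflow_fix"` (the `-k` selection expression here is just an illustrative way to pick out `test_target_device`, `test_npu_target_device`, and `test_overflow_fix`).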

#### `tests/onnx/quantization/test_ptq_params.py` (-8)
```diff
@@ -27,7 +27,6 @@
 from nncf.onnx.graph.transformations.commands import ONNXQuantizerInsertionCommand
 from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
 from nncf.parameters import TargetDevice
-from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
 from tests.common.quantization.metatypes import CatTestMetatype
@@ -49,13 +48,6 @@ def get_ignored_patterns(device: TargetDevice = TargetDevice.ANY) -> GraphPatter
     return PatternsManager.get_full_ignored_pattern_graph(backend=BackendType.ONNX, device=device)
 
 
-@pytest.mark.parametrize("target_device", TargetDevice)
-def test_target_device(target_device):
-    min_max_algo = MinMaxQuantization(target_device=target_device)
-    min_max_algo._backend_entity = ONNXMinMaxAlgoBackend()
-    assert min_max_algo._target_device == target_device
-
-
 class TestPTQParams(TemplateTestPTQParams):
     def get_algo_backend(self):
         return ONNXMinMaxAlgoBackend()
```

#### `tests/openvino/native/quantization/test_ptq_params.py` (-9)
```diff
@@ -17,7 +17,6 @@
 from nncf.common.graph.patterns.manager import PatternsManager
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationType
-from nncf.common.hardware.config import HW_CONFIG_TYPE_TARGET_DEVICE_MAP
 from nncf.common.utils.backend import BackendType
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVConcatMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
@@ -27,7 +26,6 @@
 from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
 from nncf.openvino.graph.transformations.commands import OVTargetPoint
 from nncf.parameters import TargetDevice
-from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
 from tests.common.quantization.metatypes import CatTestMetatype
@@ -48,13 +46,6 @@ def get_ignored_patterns(device: TargetDevice = TargetDevice.ANY) -> GraphPatter
     return PatternsManager.get_full_ignored_pattern_graph(backend=BackendType.OPENVINO, device=device)
 
 
-@pytest.mark.parametrize("target_device", [TargetDevice.CPU, TargetDevice.GPU, TargetDevice.NPU])
-def test_target_device(target_device):
-    min_max_algo = MinMaxQuantization(target_device=target_device)
-    min_max_algo._backend_entity = OVMinMaxAlgoBackend()
-    assert min_max_algo._target_device.value == HW_CONFIG_TYPE_TARGET_DEVICE_MAP[target_device.value]
-
-
 class TestPTQParams(TemplateTestPTQParams):
     def get_algo_backend(self):
         return OVMinMaxAlgoBackend()
```

#### `tests/torch/fx/test_ptq_params.py` (-8)
```diff
@@ -20,7 +20,6 @@
 from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
 from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
 from nncf.parameters import TargetDevice
-from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
 from nncf.torch.graph.graph import PTNNCFGraph
@@ -48,13 +47,6 @@ def get_ignored_patterns(device: TargetDevice = TargetDevice.ANY) -> GraphPatter
     return PatternsManager.get_full_ignored_pattern_graph(backend=BackendType.TORCH_FX, device=device)
 
 
-@pytest.mark.parametrize("target_device", TargetDevice)
-def test_target_device(target_device):
-    min_max_algo = MinMaxQuantization(target_device=target_device)
-    min_max_algo._backend_entity = FXMinMaxAlgoBackend()
-    assert min_max_algo._target_device == target_device
-
-
 class TestPTQParams(TemplateTestPTQParams):
     def get_algo_backend(self):
         return FXMinMaxAlgoBackend()
```

#### `tests/torch/ptq/test_ptq_params.py` (-8)
```diff
@@ -19,7 +19,6 @@
 from nncf.common.graph.transformations.commands import TransformationType
 from nncf.common.utils.backend import BackendType
 from nncf.parameters import TargetDevice
-from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
 from nncf.torch.graph.graph import PTNNCFGraph
@@ -68,13 +67,6 @@ def forward(self, x):
         return self.depthwise_conv(x)
 
 
-@pytest.mark.parametrize("target_device", TargetDevice)
-def test_target_device(target_device):
-    min_max_algo = MinMaxQuantization(target_device=target_device)
-    min_max_algo._backend_entity = PTMinMaxAlgoBackend()
-    assert min_max_algo._target_device == target_device
-
-
 class TestPTQParams(TemplateTestPTQParams):
     def get_algo_backend(self):
         return PTMinMaxAlgoBackend()
```
