
Commit f332377

Remove nncf dependency from openvino configs

1 parent bce36d2

File tree

4 files changed: +39 -42 lines changed


.github/workflows/test_openvino.yml

+6-1
@@ -35,7 +35,12 @@ jobs:
         pip install .[openvino,openvino-tokenizers,tests,diffusers] onnxruntime
     - name: Test with Pytest
       run: |
-        pytest tests/openvino/ --ignore test_modeling_basic --durations=0
+        pytest tests/openvino/ --ignore tests/openvino/test_modeling_basic.py --durations=0
+
+    - name: Test basic
+      run: |
+        pip uninstall -y nncf
+        pytest tests/openvino/test_modeling_basic.py
     - name: Test openvino-nightly
       run: |
         pip uninstall -y openvino
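
The new "Test basic" step runs the basic modeling tests with nncf removed, which only passes once the OpenVINO configuration module stops importing nncf at module level (the change in optimum/intel/openvino/configuration.py below). A minimal sketch, not part of the commit, of the kind of check this step exercises, assuming OVWeightQuantizationConfig keeps its usual bits/sym parameters:

# Sketch: with nncf uninstalled, the OpenVINO quantization configs can still be
# imported and instantiated, because nncf is no longer imported at module level.
from optimum.intel.openvino.configuration import OVWeightQuantizationConfig

config = OVWeightQuantizationConfig(bits=4, sym=True)
print(config.to_dict())  # plain config handling, no nncf import triggered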

optimum/intel/openvino/configuration.py

+16-27
@@ -16,17 +16,18 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
-import nncf
 import torch
-from nncf.quantization.advanced_parameters import OverflowFix
 from transformers import PretrainedConfig
 from transformers.utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
 
 from optimum.configuration_utils import BaseConfig
 
 
+if TYPE_CHECKING:
+    import nncf
+
 logger = logging.getLogger(__name__)
 
 _DEFAULT_4BIT_CONFIGS = {
@@ -75,7 +76,7 @@ def __init__(
             weight_only (`bool`, *optional*):
                 Used to explicitly specify type of quantization (weight-only of full) to apply.
         """
-        if isinstance(ignored_scope, nncf.IgnoredScope):
+        if not isinstance(ignored_scope, dict):
            ignored_scope = ignored_scope.__dict__
        self.ignored_scope = ignored_scope
        self.num_samples = num_samples
@@ -91,7 +92,9 @@ def post_init(self):
        if not (self.num_samples is None or isinstance(self.num_samples, int) and self.num_samples > 0):
            raise ValueError(f"`num_samples` is expected to be a positive integer, but found: {self.num_samples}")
 
-    def get_ignored_scope_instance(self) -> nncf.IgnoredScope:
+    def get_ignored_scope_instance(self) -> "nncf.IgnoredScope":
+        import nncf
+
        if self.ignored_scope is None:
            return nncf.IgnoredScope()
        return nncf.IgnoredScope(**copy.deepcopy(self.ignored_scope))
@@ -309,12 +312,12 @@ def post_init(self):
 class OVQuantizationConfig(OVQuantizationConfigBase):
     def __init__(
         self,
+        sym: bool = False,
         ignored_scope: Optional[dict] = None,
         num_samples: Optional[int] = 300,
-        preset: nncf.QuantizationPreset = None,
-        model_type: nncf.ModelType = nncf.ModelType.TRANSFORMER,
+        model_type: "nncf.ModelType" = None,
         fast_bias_correction: bool = True,
-        overflow_fix: OverflowFix = OverflowFix.DISABLE,
+        overflow_fix: str = "disable",
         weight_only: Optional[bool] = False,
         **kwargs,
     ):
@@ -323,23 +326,18 @@ def __init__(
        compression, during quantization both weights and activations are converted to lower precision.
        For weight-only model quantization please see OVWeightQuantizationConfig.
        Args:
+            sym (`bool`, defaults to `False`):
+                Whether to use symmetric quantization on the activations. Symmetric quantization will be applied on the weights in any case.
            ignored_scope (`dict`, *optional*):
                An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary
                entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
            num_samples (`int`, *optional*):
                The maximum number of samples composing the calibration dataset.
-            preset (`nncf.QuantizationPreset`, *optional*):
-                A preset controls the quantization mode (symmetric and asymmetric).
-                It can take the following values:
-                - `performance`: Symmetric quantization of weights and activations.
-                - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations.
-                Default value is None. In this case, `mixed` preset is used for `transformer`
-                model type otherwise `performance`.
            model_type (`nncf.ModelType`, defaults to nncf.ModelType.TRANSFORMER):
                Model type is needed to specify additional patterns in the model. Supported only `transformer` now.
            fast_bias_correction (`bool`, defaults to True):
                Whether to apply fast or full bias correction algorithm.
-            overflow_fix (`nncf.OverflowFix`, default to OverflowFix.DISABLE):
+            overflow_fix (`str`, default to "disable"):
                Parameter for controlling overflow fix setting.
            weight_only (`bool`, *optional*):
                Used to explicitly specify type of quantization (weight-only of full) to apply. Useful when building
@@ -352,33 +350,24 @@ def __init__(
        )
        super().__init__(ignored_scope, num_samples, False)
        # TODO: remove checks below once NNCF is updated to 2.10
-        if isinstance(overflow_fix, str):
-            overflow_fix = OverflowFix(overflow_fix)
-        if isinstance(preset, str):
-            preset = nncf.QuantizationPreset(preset)
-
-        self.preset = preset
+        self.sym = sym
        self.model_type = model_type
        self.fast_bias_correction = fast_bias_correction
        self.overflow_fix = overflow_fix
        self.post_init()
 
    def to_dict(self) -> Dict[str, Any]:
        # TODO: remove code below once NNCF is updated to 2.10
-        if isinstance(self.overflow_fix, Enum) or isinstance(self.preset, Enum):
+        if isinstance(self.overflow_fix, Enum):
            overflow_fix_value = (
                None
                if self.overflow_fix is None
                else self.overflow_fix
                if isinstance(self.overflow_fix, str)
                else self.overflow_fix.value
            )
-            preset_value = (
-                None if self.preset is None else self.preset if isinstance(self.preset, str) else self.preset.value
-            )
            self_copy = copy.deepcopy(self)
            self_copy.overflow_fix = overflow_fix_value
-            self_copy.preset = preset_value
            return self_copy.to_dict()
        return super().to_dict()
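
Taken together, these changes make nncf optional for configuration: `preset` is replaced by a plain `sym` flag, `overflow_fix` becomes a string, and `nncf.IgnoredScope` is only created inside `get_ignored_scope_instance()`. A short usage sketch under those assumptions (values mirror the new defaults):

# Sketch: building the full-quantization config no longer touches nncf.
from optimum.intel.openvino.configuration import OVQuantizationConfig

q_config = OVQuantizationConfig(
    sym=False,               # asymmetric activations (the old "mixed" preset behaviour)
    num_samples=300,
    overflow_fix="disable",  # plain string; mapped to nncf.OverflowFix at quantization time
)
print(q_config.to_dict())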

optimum/intel/openvino/quantization.py

+13-7
@@ -26,7 +26,7 @@
 import torch
 import transformers
 from nncf import CompressWeightsMode, SensitivityMetric
-from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
+from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters, OverflowFix
 from nncf.torch import register_module
 from nncf.torch.initialization import PTInitializingDataLoader
 from openvino._offline_transformations import compress_quantize_weights_transformation
@@ -378,10 +378,12 @@ def _quantize_ovbasemodel(
            quantization_dataset,
            subset_size=quantization_config.num_samples,
            ignored_scope=quantization_config.get_ignored_scope_instance(),
-            model_type=quantization_config.model_type,
-            preset=quantization_config.preset,
+            model_type=quantization_config.model_type or nncf.ModelType.TRANSFORMER,
+            preset=nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED,
            fast_bias_correction=quantization_config.fast_bias_correction,
-            advanced_parameters=nncf.AdvancedQuantizationParameters(overflow_fix=quantization_config.overflow_fix),
+            advanced_parameters=nncf.AdvancedQuantizationParameters(
+                overflow_fix=OverflowFix(quantization_config.overflow_fix)
+            ),
            **kwargs,
        )
        self.model.model = quantized_model
@@ -476,10 +478,14 @@ def _quantize_torchmodel(
            quantization_dataset,
            subset_size=quantization_config.num_samples,
            ignored_scope=quantization_config.get_ignored_scope_instance(),
-            model_type=quantization_config.model_type,
-            preset=quantization_config.preset,
+            model_type=quantization_config.model_type or nncf.ModelType.TRANSFORMER,
+            preset=nncf.QuantizationPreset.PERFORMANCE
+            if quantization_config.sym
+            else nncf.QuantizationPreset.MIXED,
            fast_bias_correction=quantization_config.fast_bias_correction,
-            advanced_parameters=nncf.AdvancedQuantizationParameters(overflow_fix=quantization_config.overflow_fix),
+            advanced_parameters=nncf.AdvancedQuantizationParameters(
+                overflow_fix=OverflowFix(quantization_config.overflow_fix)
+            ),
            **kwargs,
        )
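
The quantizer itself still depends on nncf and translates the plain config values back into nncf objects at call time. A condensed sketch of that mapping (the `resolve_nncf_args` helper is hypothetical, shown only to summarize the diff above):

import nncf
from nncf.quantization.advanced_parameters import OverflowFix

def resolve_nncf_args(quantization_config):
    # Hypothetical helper: turn nncf-free config values into nncf.quantize() arguments,
    # mirroring _quantize_ovbasemodel / _quantize_torchmodel above.
    return {
        "model_type": quantization_config.model_type or nncf.ModelType.TRANSFORMER,
        "preset": nncf.QuantizationPreset.PERFORMANCE
        if quantization_config.sym
        else nncf.QuantizationPreset.MIXED,
        "advanced_parameters": nncf.AdvancedQuantizationParameters(
            overflow_fix=OverflowFix(quantization_config.overflow_fix)
        ),
    }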

tests/openvino/test_quantization.py

+4-7
@@ -748,7 +748,7 @@ class OVQuantizationConfigTest(unittest.TestCase):
        OVQuantizationConfig(
            ignored_scope={"names": ["op_name"]},
            num_samples=100,
-            preset=nncf.QuantizationPreset.MIXED,
+            sym=False,
            model_type=nncf.ModelType.TRANSFORMER,
            fast_bias_correction=True,
            overflow_fix=OverflowFix.DISABLE,
@@ -794,7 +794,7 @@ class OVQuantizationConfigTest(unittest.TestCase):
        dict(
            ignored_scope={"names": ["op_name"]},
            num_samples=100,
-            preset=nncf.QuantizationPreset.MIXED,
+            sym=False,
            model_type=nncf.ModelType.TRANSFORMER,
            fast_bias_correction=True,
            overflow_fix=OverflowFix.DISABLE,
@@ -834,14 +834,11 @@ def str_to_enum(enum_cls, value):
            return
        for key, value in loaded_ov_config.quantization_config.to_dict().items():
            initial_value = getattr(ov_config.quantization_config, key)
-            if key == "preset" or key == "overflow_fix":
+            if key == "overflow_fix":
                # TODO: remove once NNCF is updated to 2.10
                if getattr(quantization_config, key) is not None:
                    self.assertTrue(isinstance(value, str))
-                    if key == "preset":
-                        value = str_to_enum(nncf.QuantizationPreset, value)
-                    else:
-                        value = str_to_enum(OverflowFix, value)
+                    value = str_to_enum(OverflowFix, value)
            self.assertEqual(value, initial_value)
 
    @parameterized.expand(QUANTIZATION_CONFIG_DICTS)
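
The serialization round-trip the updated test relies on can be summarized as follows (a sketch, assuming the config's default to_dict behaviour):

# Sketch: overflow_fix now round-trips as a plain string in the serialized config.
from optimum.intel.openvino.configuration import OVQuantizationConfig

config = OVQuantizationConfig(sym=False, num_samples=100, overflow_fix="disable")
assert config.to_dict()["overflow_fix"] == "disable"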

0 commit comments
