Skip to content

Commit a1d9580

Browse files
Make quantization config contain only serializable properties.
1 parent 20fd761 commit a1d9580

File tree

5 files changed

+340
-254
lines changed

5 files changed

+340
-254
lines changed

optimum/intel/openvino/configuration.py

+117-109
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import json
14+
import copy
15+
import inspect
1516
import logging
1617
from dataclasses import dataclass
1718
from enum import Enum
18-
from typing import Any, Dict, List, Optional, Tuple, Union
19+
from typing import Any, Dict, List, Optional, Union
1920

20-
import datasets
2121
import nncf
2222
import torch
2323
from nncf.quantization.advanced_parameters import OverflowFix
@@ -52,36 +52,6 @@
5252
}
5353

5454

55-
class _replace_properties_values:
56-
"""
57-
A context manager for temporarily overriding an object's properties
58-
"""
59-
60-
def __init__(self, obj, property_names, property_values):
61-
self.obj = obj
62-
self.property_names = property_names
63-
self.new_property_values = property_values
64-
self.old_property_values = [None] * len(property_names)
65-
for i, property_name in enumerate(self.property_names):
66-
self.old_property_values[i] = getattr(obj, property_name)
67-
68-
def __enter__(self):
69-
for property_name, new_property_value in zip(self.property_names, self.new_property_values):
70-
setattr(self.obj, property_name, new_property_value)
71-
72-
def __exit__(self, exc_type, exc_val, exc_tb):
73-
for property_name, old_property_value in zip(self.property_names, self.old_property_values):
74-
setattr(self.obj, property_name, old_property_value)
75-
76-
77-
def _is_serializable(obj):
78-
try:
79-
json.dumps(obj)
80-
return True
81-
except Exception:
82-
return False
83-
84-
8555
@dataclass
8656
class OVQuantizationConfigBase(QuantizationConfigMixin):
8757
"""
@@ -90,53 +60,41 @@ class OVQuantizationConfigBase(QuantizationConfigMixin):
9060

9161
def __init__(
9262
self,
93-
dataset: Optional[Union[str, List[str], nncf.Dataset, datasets.Dataset]] = None,
94-
ignored_scope: Optional[Union[dict, nncf.IgnoredScope]] = None,
63+
ignored_scope: Optional[dict] = None,
9564
num_samples: Optional[int] = None,
65+
weight_only: Optional[bool] = None,
66+
**kwargs,
9667
):
9768
"""
9869
Args:
99-
dataset (`str or List[str] or nncf.Dataset or datasets.Dataset`, *optional*):
100-
The dataset used for data-aware weight compression or quantization with NNCF.
101-
ignored_scope (`dict or nncf.IgnoredScope`, *optional*):
102-
An ignored scope that defines the list of model nodes to be ignored during quantization.
70+
ignored_scope (`dict`, *optional*):
71+
An ignored scope that defines a list of model nodes to be ignored during quantization. Dictionary
72+
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
10373
num_samples (`int`, *optional*):
10474
The maximum number of samples composing the calibration dataset.
75+
weight_only (`bool`, *optional*):
76+
Used to explicitly specify type of quantization (weight-only of full) to apply.
10577
"""
106-
self.dataset = dataset
107-
if isinstance(ignored_scope, dict):
108-
ignored_scope = nncf.IgnoredScope(**ignored_scope)
78+
if isinstance(ignored_scope, nncf.IgnoredScope):
79+
ignored_scope = ignored_scope.__dict__
10980
self.ignored_scope = ignored_scope
11081
self.num_samples = num_samples
82+
self.weight_only = weight_only
11183

11284
def post_init(self):
113-
if not (self.dataset is None or isinstance(self.dataset, (str, list, nncf.Dataset, datasets.Dataset))):
85+
try:
86+
self.get_ignored_scope_instance()
87+
except Exception as e:
11488
raise ValueError(
115-
"Dataset must be a instance of either string, list of strings, nncf.Dataset or "
116-
f"dataset.Dataset, but found {type(self.dataset)}"
117-
)
118-
if not (self.ignored_scope is None or isinstance(self.ignored_scope, nncf.IgnoredScope)):
119-
raise ValueError(
120-
"Ignored scope must be a instance of either dict, or nncf.IgnoredScope but found "
121-
f"{type(self.dataset)}"
89+
f"Can't create an `IgnoredScope` object from the provided ignored scope dict: {self.ignored_scope}.\n{e}"
12290
)
91+
if not (self.num_samples is None or isinstance(self.num_samples, int) and self.num_samples > 0):
92+
raise ValueError(f"`num_samples` is expected to be a positive integer, but found: {self.num_samples}")
12393

124-
def _to_dict_without_properties(self, property_names: Union[List[str], Tuple[str]]) -> Dict[str, Any]:
125-
"""
126-
Calls to_dict() with given properties overwritten with None. Useful for hiding non-serializable properties.
127-
"""
128-
if len(property_names) == 0:
129-
return super().to_dict()
130-
with _replace_properties_values(self, property_names, [None] * len(property_names)):
131-
result = super().to_dict()
132-
return result
133-
134-
def to_dict(self) -> Dict[str, Any]:
135-
properties_to_omit = [] if _is_serializable(self.dataset) else ["dataset"]
136-
if isinstance(self.ignored_scope, nncf.IgnoredScope):
137-
with _replace_properties_values(self, ["ignored_scope"], [self.ignored_scope.__dict__]):
138-
return self._to_dict_without_properties(properties_to_omit)
139-
return self._to_dict_without_properties(properties_to_omit)
94+
def get_ignored_scope_instance(self) -> nncf.IgnoredScope:
95+
if self.ignored_scope is None:
96+
return nncf.IgnoredScope()
97+
return nncf.IgnoredScope(**copy.deepcopy(self.ignored_scope))
14098

14199

142100
class OVConfig(BaseConfig):
@@ -155,16 +113,11 @@ def __init__(
155113
self.input_info = input_info
156114
self.save_onnx_model = save_onnx_model
157115
self.optimum_version = kwargs.pop("optimum_version", None)
116+
if isinstance(quantization_config, dict):
117+
quantization_config = self._quantization_config_from_dict(quantization_config)
158118
self.quantization_config = quantization_config
159119
self.compression = None # A backward-compatability field for training-time compression parameters
160120

161-
if isinstance(self.quantization_config, dict):
162-
# Config is loaded as dict during deserialization
163-
logger.info(
164-
"`quantization_config` was provided as a dict, in this form it can't be used for quantization. "
165-
"Please provide config as an instance of OVWeightQuantizationConfig or OVQuantizationConfig"
166-
)
167-
168121
bits = (
169122
self.quantization_config.bits if isinstance(self.quantization_config, OVWeightQuantizationConfig) else None
170123
)
@@ -180,12 +133,40 @@ def add_input_info(self, model_inputs: Dict, force_batch_one: bool = False):
180133
for name, value in model_inputs.items()
181134
]
182135

136+
@staticmethod
137+
def _quantization_config_from_dict(quantization_config: dict) -> OVQuantizationConfigBase:
138+
wq_args = inspect.getfullargspec(OVWeightQuantizationConfig.__init__).args
139+
q_args = inspect.getfullargspec(OVQuantizationConfig.__init__).args
140+
config_keys = quantization_config.keys()
141+
matches_wq_config_signature = all(arg_name in wq_args for arg_name in config_keys)
142+
matches_q_config_signature = all(arg_name in q_args for arg_name in config_keys)
143+
if matches_wq_config_signature == matches_q_config_signature:
144+
weight_only = quantization_config.get("weight_only", None)
145+
if weight_only is None:
146+
logger.warning(
147+
"Can't determine type of OV quantization config. Please specify explicitly whether you intend to "
148+
"run weight-only quantization or not with `weight_only` parameter. Creating an instance of "
149+
"OVWeightQuantizationConfig."
150+
)
151+
return OVWeightQuantizationConfig.from_dict(quantization_config)
152+
matches_wq_config_signature = weight_only
153+
154+
config_type = OVWeightQuantizationConfig if matches_wq_config_signature else OVQuantizationConfig
155+
return config_type.from_dict(quantization_config)
156+
183157
def _to_dict_safe(self, to_diff_dict: bool = False) -> Dict[str, Any]:
158+
class ConfigStub:
159+
def to_dict(self):
160+
return None
161+
162+
def to_diff_dict(self):
163+
return None
164+
184165
if self.quantization_config is None:
185166
# Parent to_dict() implementation does not support quantization_config being None
186-
with _replace_properties_values(self, ("quantization_config",), (OVQuantizationConfigBase(),)):
187-
result = super().to_diff_dict() if to_diff_dict else super().to_dict()
188-
del result["quantization_config"]
167+
self_copy = copy.deepcopy(self)
168+
self_copy.quantization_config = ConfigStub()
169+
result = self_copy.to_diff_dict() if to_diff_dict else self_copy.to_dict()
189170
else:
190171
result = super().to_diff_dict() if to_diff_dict else super().to_dict()
191172
return result
@@ -212,9 +193,8 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
212193
The number of bits to quantize to.
213194
sym (`bool`, defaults to `False`):
214195
Whether to use symmetric quantization.
215-
tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
196+
tokenizer (`str`, *optional*):
216197
The tokenizer used to process the dataset. You can pass either:
217-
- A custom tokenizer object.
218198
- A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
219199
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
220200
user or organization name, like `dbmdz/bert-base-german-cased`.
@@ -224,6 +204,8 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
224204
The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset
225205
in a list of strings or just use the one from the list ['wikitext','c4','c4-new','ptb','ptb-new'] for LLLMs
226206
or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models.
207+
Alternatively, you can provide data objects via `calibration_dataset` argument
208+
of `OVQuantizer.quantize()` method.
227209
ratio (`float`, defaults to 1.0):
228210
The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4_ASYM
229211
and the rest to INT8_ASYM).
@@ -235,32 +217,44 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
235217
The sensitivity metric for assigning quantization precision to layers. In order to
236218
preserve the accuracy of the model, the more sensitive layers receives a higher precision.
237219
ignored_scope (`dict`, *optional*):
238-
An ignored scope that defined the list of model control flow graph nodes to be ignored during quantization.
220+
An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary
221+
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
239222
num_samples (`int`, *optional*):
240223
The maximum number of samples composing the calibration dataset.
241224
quant_method (`str`, defaults of OVQuantizationMethod.DEFAULT):
242225
Weight compression method to apply.
226+
weight_only (`bool`, *optional*):
227+
Used to explicitly specify type of quantization to apply.
228+
weight_only (`bool`, *optional*):
229+
Used to explicitly specify type of quantization (weight-only of full) to apply.
243230
"""
244231

245232
def __init__(
246233
self,
247234
bits: int = 8,
248235
sym: bool = False,
249-
tokenizer: Optional[Any] = None,
250-
dataset: Optional[Union[str, List[str], nncf.Dataset, datasets.Dataset]] = None,
236+
tokenizer: Optional[str] = None,
237+
dataset: Optional[Union[str, List[str]]] = None,
251238
ratio: float = 1.0,
252239
group_size: Optional[int] = None,
253240
all_layers: Optional[bool] = None,
254241
sensitivity_metric: Optional[str] = None,
255-
ignored_scope: Optional[Union[dict, nncf.IgnoredScope]] = None,
242+
ignored_scope: Optional[dict] = None,
256243
num_samples: Optional[int] = None,
257244
quant_method: Optional[Union[QuantizationMethod, OVQuantizationMethod]] = OVQuantizationMethod.DEFAULT,
245+
weight_only: Optional[bool] = True,
258246
**kwargs,
259247
):
260-
super().__init__(dataset, ignored_scope, num_samples)
248+
if weight_only is False:
249+
logger.warning(
250+
"Trying to create an instance of `OVWeightQuantizationConfig` with `weight_only` being "
251+
"False. Please check your configuration."
252+
)
253+
super().__init__(ignored_scope, num_samples, True)
261254
self.bits = bits
262255
self.sym = sym
263256
self.tokenizer = tokenizer
257+
self.dataset = dataset
264258
self.group_size = group_size or (-1 if bits == 8 else 128)
265259
self.ratio = ratio
266260
self.all_layers = all_layers
@@ -277,6 +271,11 @@ def post_init(self):
277271
raise ValueError("`ratio` must between 0 and 1.")
278272
if self.group_size is not None and self.group_size != -1 and self.group_size <= 0:
279273
raise ValueError("`group_size` must be greater than 0 or equal to -1")
274+
if not (self.dataset is None or isinstance(self.dataset, (str, list))):
275+
raise ValueError(
276+
f"Dataset must be a instance of either string or list of strings, but found {type(self.dataset)}. "
277+
f"If you wish to provide a custom dataset please pass it via `calibration_dataset` argument."
278+
)
280279
if self.dataset is not None and isinstance(self.dataset, str):
281280
llm_datasets = ["wikitext", "c4", "c4-new", "ptb", "ptb-new"]
282281
stable_diffusion_datasets = [
@@ -303,34 +302,31 @@ def post_init(self):
303302
f"For 8-bit quantization, `group_size` is expected to be set to -1, but was set to {self.group_size}"
304303
)
305304

306-
def to_dict(self) -> Dict[str, Any]:
307-
if not _is_serializable(self.tokenizer):
308-
return self._to_dict_without_properties(("tokenizer",))
309-
return super().to_dict()
305+
if self.tokenizer is not None and not isinstance(self.tokenizer, str):
306+
raise ValueError(f"Tokenizer is expected to be a string, but found {self.tokenizer}")
310307

311308

312309
@dataclass
313310
class OVQuantizationConfig(OVQuantizationConfigBase):
314311
def __init__(
315312
self,
316-
dataset: Union[str, List[str], nncf.Dataset, datasets.Dataset],
317-
ignored_scope: Optional[Union[dict, nncf.IgnoredScope]] = None,
313+
ignored_scope: Optional[dict] = None,
318314
num_samples: Optional[int] = 300,
319315
preset: nncf.QuantizationPreset = None,
320316
model_type: nncf.ModelType = nncf.ModelType.TRANSFORMER,
321317
fast_bias_correction: bool = True,
322318
overflow_fix: OverflowFix = OverflowFix.DISABLE,
319+
weight_only: Optional[bool] = False,
323320
**kwargs,
324321
):
325322
"""
326323
Configuration class containing parameters related to model quantization with NNCF. Compared to weight
327324
compression, during quantization both weights and activations are converted to lower precision.
328325
For weight-only model quantization please see OVWeightQuantizationConfig.
329326
Args:
330-
dataset (`str or List[str] or nncf.Dataset or datasets.Dataset`):
331-
A dataset used for quantization parameters calibration. Required parameter.
332-
ignored_scope (`dict or nncf.IgnoredScope`, *optional*):
333-
An ignored scope that defines the list of model nodes to be ignored during quantization.
327+
ignored_scope (`dict`, *optional*):
328+
An ignored scope that defines the list of model nodes to be ignored during quantization. Dictionary
329+
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
334330
num_samples (`int`, *optional*):
335331
The maximum number of samples composing the calibration dataset.
336332
preset (`nncf.QuantizationPreset`, *optional*):
@@ -346,31 +342,43 @@ def __init__(
346342
Whether to apply fast or full bias correction algorithm.
347343
overflow_fix (`nncf.OverflowFix`, default to OverflowFix.DISABLE):
348344
Parameter for controlling overflow fix setting.
345+
weight_only (`bool`, *optional*):
346+
Used to explicitly specify type of quantization (weight-only of full) to apply.
349347
"""
350-
super().__init__(dataset, ignored_scope, num_samples)
348+
if weight_only is True:
349+
logger.warning(
350+
"Trying to create an instance of `OVQuantizationConfig` with `weight_only` being True. "
351+
"Please check your configuration."
352+
)
353+
super().__init__(ignored_scope, num_samples, False)
354+
# TODO: remove checks below once NNCF is updated to 2.10
355+
if isinstance(overflow_fix, str):
356+
overflow_fix = OverflowFix(overflow_fix)
357+
if isinstance(preset, str):
358+
preset = nncf.QuantizationPreset(preset)
359+
351360
self.preset = preset
352361
self.model_type = model_type
353362
self.fast_bias_correction = fast_bias_correction
354363
self.overflow_fix = overflow_fix
355364
self.post_init()
356365

357-
def post_init(self):
358-
"""
359-
Safety checker that arguments are correct
360-
"""
361-
super().post_init()
362-
if self.dataset is None:
363-
raise ValueError(
364-
"`dataset` is needed to compute the activations range during the calibration step and was not provided."
365-
" In case you only want to apply quantization on the weights, please run weight-only quantization."
366-
)
367-
368366
def to_dict(self) -> Dict[str, Any]:
369367
# TODO: remove code below once NNCF is updated to 2.10
370-
overflow_fix_value = None if self.overflow_fix is None else self.overflow_fix.value
371-
preset_value = None if self.preset is None else self.preset.value
372-
with _replace_properties_values(self, ("overflow_fix", "preset"), (overflow_fix_value, preset_value)):
373-
return super().to_dict()
368+
if isinstance(self.overflow_fix, Enum) or isinstance(self.preset, Enum):
369+
overflow_fix_value = (
370+
None
371+
if self.overflow_fix is None
372+
else self.overflow_fix if isinstance(self.overflow_fix, str) else self.overflow_fix.value
373+
)
374+
preset_value = (
375+
None if self.preset is None else self.preset if isinstance(self.preset, str) else self.preset.value
376+
)
377+
self_copy = copy.deepcopy(self)
378+
self_copy.overflow_fix = overflow_fix_value
379+
self_copy.preset = preset_value
380+
return self_copy.to_dict()
381+
return super().to_dict()
374382

375383

376384
def _check_default_4bit_configs(config: PretrainedConfig):

0 commit comments

Comments
 (0)