@@ -208,16 +208,12 @@ def __init__(
208
208
self ._ignored_scope = IgnoredScope () if ignored_scope is None else ignored_scope
209
209
self .quantizer_propagation_rule = quantizer_propagation_rule
210
210
211
- # preset definition
212
- if self ._preset is None :
213
- if model_type == ModelType .TRANSFORMER :
214
- self ._preset = QuantizationPreset .MIXED
215
- else :
216
- self ._preset = QuantizationPreset .PERFORMANCE
211
+ # validate input parameter types
212
+ self ._validate_param_types ()
217
213
218
- self . _override_device ()
219
- self ._set_mode_based_defaults ()
220
- self ._review_mode_based_defaults ()
214
+ # set and validate mode based parameters
215
+ self ._set_mode_based_params ()
216
+ self ._review_mode_based_params ()
221
217
222
218
self ._quantization_params = {
223
219
QuantizerGroup .WEIGHTS : self ._weights_quantization_params ,
@@ -238,35 +234,64 @@ def __init__(
238
234
self ._reset_cache ()
239
235
self ._algorithm_key = f"MMQ_{ hash (self )} "
240
236
241
def _validate_param_types(self) -> None:
    """
    Checks that the weights/activations quantization parameter objects match the selected mode.

    Integer modes expect ``QuantizationParameters``; the FP8 modes
    (``FP8_E4M3`` / ``FP8_E5M2``) expect ``FP8QuantizationParameters``.

    Raises:
        nncf.ParameterNotSupportedError: If the parameter types do not match the expected quantization mode.
    """
    # FP8 modes require the FP8-specific parameter class; everything else
    # uses the regular integer-quantization parameter class.
    if self._mode in (QuantizationMode.FP8_E4M3, QuantizationMode.FP8_E5M2):
        expected_cls = FP8QuantizationParameters
    else:
        expected_cls = QuantizationParameters

    checks = (
        ("weights", self._weights_quantization_params),
        ("activations", self._activations_quantization_params),
    )
    for name, param in checks:
        # Unset (None/empty) parameters are fine; only a populated object of
        # the wrong class is rejected.
        if param and not isinstance(param, expected_cls):
            msg = f"Quantization parameters for {name} ({param}) are not supported with the selected mode!"
            raise nncf.ParameterNotSupportedError(msg)
256
+ def _set_mode_based_params (self ) -> None :
242
257
"""
243
- Overrides NPU device to use CPU quantization scheme .
258
+ Sets parameters for the algorithms based on the provided mode .
244
259
"""
245
- if self ._target_device == TargetDevice .NPU :
246
- act_bits , weight_bits = 8 , 8
260
+ if self ._mode is None :
261
+ if self ._preset is None :
262
+ if self ._model_type == ModelType .TRANSFORMER :
263
+ self ._preset = QuantizationPreset .MIXED
264
+ else :
265
+ self ._preset = QuantizationPreset .PERFORMANCE
266
+
267
+ act_bits = DEFAULT_QCONFIG .num_bits
268
+ weight_bits = DEFAULT_QCONFIG .num_bits
247
269
if self ._activations_quantization_params and self ._activations_quantization_params .num_bits :
248
270
act_bits = self ._activations_quantization_params .num_bits
249
271
if self ._weights_quantization_params and self ._weights_quantization_params .num_bits :
250
272
weight_bits = self ._weights_quantization_params .num_bits
251
273
252
- if act_bits == 8 and weight_bits == 8 :
253
- self ._target_device == TargetDevice .CPU
274
+ quant_scheme_a8w8 = act_bits == 8 and weight_bits == 8
275
+ if self ._target_device == TargetDevice .NPU and quant_scheme_a8w8 :
276
+ self ._target_device = TargetDevice .CPU
254
277
nncf_logger .debug ("Target device NPU was changed to CPU!" )
255
278
256
- def _set_mode_based_defaults (self ) -> None :
257
- """
258
- Sets defaults for the algorithms based on the provided mode.
259
- """
279
+ if self ._overflow_fix is None and not quant_scheme_a8w8 :
280
+ self ._overflow_fix = OverflowFix .DISABLE
281
+ nncf_logger .debug ("Overflow fix was disabled because quantization scheme is not A8W8." )
282
+ elif self ._preset is None :
283
+ self ._preset = QuantizationPreset .PERFORMANCE
284
+
260
285
mode_based_defaults = MODE_BASED_DEFAULTS [self ._mode ]
261
286
for field in dataclasses .fields (mode_based_defaults ):
262
287
self_name = "_" + field .name
263
288
default_value = getattr (mode_based_defaults , field .name )
264
289
if getattr (self , self_name ) is None :
265
290
setattr (self , self_name , default_value )
266
291
267
- def _review_mode_based_defaults (self ):
292
+ def _review_mode_based_params (self ):
268
293
"""
269
- Reviews default values because mode option doesn't support them.
294
+ Reviews parameter values because mode option doesn't support them.
270
295
"""
271
296
if self ._mode in (QuantizationMode .FP8_E4M3 , QuantizationMode .FP8_E5M2 ):
272
297
nncf_logger .warning (f"You're using experimental option mode with { self ._mode } value." )
@@ -287,38 +312,6 @@ def _review_mode_based_defaults(self):
287
312
msg = "quantize_outputs option is not supported with the mode option!"
288
313
raise nncf .ParameterNotSupportedError (msg )
289
314
290
- if isinstance (self ._weights_quantization_params , QuantizationParameters ):
291
- msg = (
292
- "quantization_params option for weights with "
293
- f"{ self ._weights_quantization_params } "
294
- "value is not supported with the mode option!"
295
- )
296
- raise nncf .ParameterNotSupportedError (msg )
297
-
298
- if isinstance (self ._activations_quantization_params , QuantizationParameters ):
299
- msg = (
300
- "quantization_params option for activations with "
301
- f"{ self ._activations_quantization_params } "
302
- "value is not supported with the mode option!"
303
- )
304
- raise nncf .ParameterNotSupportedError (msg )
305
- elif self ._mode is None :
306
- if isinstance (self ._weights_quantization_params , FP8QuantizationParameters ):
307
- msg = (
308
- "quantization_params option for weights with "
309
- f"{ self ._weights_quantization_params } "
310
- "value is not supported with the mode: None option!"
311
- )
312
- raise nncf .ParameterNotSupportedError (msg )
313
-
314
- if isinstance (self ._activations_quantization_params , FP8QuantizationParameters ):
315
- msg = (
316
- "quantization_params option for activations with "
317
- f"{ self ._activations_quantization_params } "
318
- "value is not supported with the mode: None option!"
319
- )
320
- raise nncf .ParameterNotSupportedError (msg )
321
-
322
315
def _reset_cache (self ) -> None :
323
316
"""
324
317
Marks cache by noninitialized values. Needs to be called when the new quantizer setup is needed.
0 commit comments