@@ -18,7 +18,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 import torch
 from transformers.utils.quantization_config import QuantizationConfigMixin
@@ -571,9 +571,7 @@ def to_nncf_dict(self) -> Dict[str, Any]:
             mode = "e2m1"
         mode = nncf.CompressWeightsMode(mode)
 
-        awq = None
-        if self.quant_method == "awq" or self.quant_method == OVQuantizationMethod.AWQ:
-            awq = True
+        awq = True if self.quant_method == OVQuantizationMethod.AWQ else None
         sensitivity_metric = nncf.SensitivityMetric(self.sensitivity_metric) if self.sensitivity_metric else None
         backup_mode = nncf.BackupMode(self.backup_precision) if self.backup_precision else None
         result = {
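The collapsed ternary above drops the explicit `self.quant_method == "awq"` string check, which is safe only because that check was redundant. A minimal sketch of why, assuming `OVQuantizationMethod` is a `str`-based `Enum` (the enum definition below is a stand-in, not copied from this diff):

```python
from enum import Enum


# Stand-in for OVQuantizationMethod; assumed to subclass str, so comparing a
# member against the plain string "awq" also evaluates to True.
class OVQuantizationMethod(str, Enum):
    DEFAULT = "default"
    AWQ = "awq"


quant_method = "awq"  # e.g. a raw string deserialized from a config dict
awq = True if quant_method == OVQuantizationMethod.AWQ else None
assert awq is True  # the plain-string case still matches the enum member
```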
@@ -896,21 +894,22 @@ def __init__(
                 machine arbitrary code present in the model repository.
             **kwargs:
         """
-        if isinstance(weight_quantization_config, dict):
-            weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_config)
-        else:
-            weight_quantization_config = weight_quantization_config.clone()
-        self.weight_quantization_config = weight_quantization_config
+        self.weight_quantization_config = self._initialize_quantization_config(
+            weight_quantization_config, OVWeightQuantizationConfig
+        )
         wqc = self.weight_quantization_config
 
-        if isinstance(full_quantization_config, dict):
-            full_quantization_config = OVQuantizationConfig.from_dict(full_quantization_config)
-        else:
-            full_quantization_config = full_quantization_config.clone()
-        self.full_quantization_config = full_quantization_config
+        self.full_quantization_config = self._initialize_quantization_config(
+            full_quantization_config, OVQuantizationConfig
+        )
         fqc = self.full_quantization_config
 
         if fqc.dtype in ["f8e4m3", "f8e5m2"] and wqc.backup_precision is None:
+            # Here we simulate FP8 backup weight compression precision through full quantization: during the weight
+            # compression step, some weighted layers are kept in their original precision and are later compressed
+            # to FP8 during the full quantization step.
+            # The issue with the current approach is that if one provides an ignored scope for the full quantization
+            # step, then the weights of the layers under that ignored scope won't be compressed to FP8.
             # TODO: remove once there is support for FP8 weight compression in NNCF
             wqc.backup_precision = "none"
 
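To make the FP8 backup handling above concrete, a hedged usage sketch. It assumes the enclosing `__init__` belongs to optimum-intel's `OVMixedQuantizationConfig` (the class name does not appear in this hunk), that it is importable from `optimum.intel`, and that the constructor arguments shown (`bits=4`, `dtype="f8e4m3"`) are valid; only the `backup_precision = "none"` outcome is taken directly from the diff:

```python
from optimum.intel import OVMixedQuantizationConfig, OVQuantizationConfig, OVWeightQuantizationConfig

# Weight-only compression plus FP8 full quantization, with no explicit
# backup precision requested by the user.
config = OVMixedQuantizationConfig(
    weight_quantization_config=OVWeightQuantizationConfig(bits=4),
    full_quantization_config=OVQuantizationConfig(dtype="f8e4m3"),
)

# Per the branch above: backup is disabled for the weight compression step, so
# the layers skipped there are compressed to FP8 during full quantization.
assert config.weight_quantization_config.backup_precision == "none"
```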
@@ -932,6 +931,21 @@ def __init__(
 
         self.post_init()
 
+    @staticmethod
+    def _initialize_quantization_config(
+        config: Union[dict, OVWeightQuantizationConfig, OVQuantizationConfig],
+        config_type: Type[Union[OVWeightQuantizationConfig, OVQuantizationConfig]],
+    ):
+        if isinstance(config, dict):
+            return config_type.from_dict(config)
+        elif isinstance(config, config_type):
+            return config.clone()
+        else:
+            raise ValueError(
+                f"Unsupported type of quantization config. Expected either a dictionary or an instance of "
+                f"{config_type}, but found: {type(config)}."
+            )
+
     def to_dict(self):
         result = super().to_dict()
         result["weight_quantization_config"] = self.weight_quantization_config.to_dict()