18
18
from dataclasses import asdict , dataclass , field
19
19
from enum import Enum
20
20
from pathlib import Path
21
- from typing import Dict , List , Optional , Tuple , Union
21
+ from typing import Any , Dict , List , Optional , Tuple , Union
22
22
23
23
from datasets import Dataset
24
24
from packaging .version import Version , parse
@@ -298,6 +298,15 @@ def __post_init__(self):
298
298
)
299
299
self .operators_to_quantize = operators_to_quantize
300
300
301
+ if isinstance (self .format , str ):
302
+ self .format = QuantFormat [self .format ]
303
+ if isinstance (self .mode , str ):
304
+ self .mode = QuantizationMode [self .mode ]
305
+ if isinstance (self .activations_dtype , str ):
306
+ self .activations_dtype = QuantType [self .activations_dtype ]
307
+ if isinstance (self .weights_dtype , str ):
308
+ self .weights_dtype = QuantType [self .weights_dtype ]
309
+
301
310
@staticmethod
302
311
def quantization_type_str (activations_dtype : QuantType , weights_dtype : QuantType ) -> str :
303
312
return (
@@ -984,8 +993,28 @@ def __init__(
984
993
self .opset = opset
985
994
self .use_external_data_format = use_external_data_format
986
995
self .one_external_file = one_external_file
987
- self .optimization = self .dataclass_to_dict (optimization )
988
- self .quantization = self .dataclass_to_dict (quantization )
996
+
997
+ if isinstance (optimization , dict ) and optimization :
998
+ self .optimization = OptimizationConfig (** optimization )
999
+ elif isinstance (optimization , OptimizationConfig ):
1000
+ self .optimization = optimization
1001
+ elif not optimization :
1002
+ self .optimization = None
1003
+ else :
1004
+ raise ValueError (
1005
+ f"Optional argument `optimization` must be a dictionary or an instance of OptimizationConfig, got { type (optimization )} "
1006
+ )
1007
+ if isinstance (quantization , dict ) and quantization :
1008
+ self .quantization = QuantizationConfig (** quantization )
1009
+ elif isinstance (quantization , QuantizationConfig ):
1010
+ self .quantization = quantization
1011
+ elif not quantization :
1012
+ self .quantization = None
1013
+ else :
1014
+ raise ValueError (
1015
+ f"Optional argument `quantization` must be a dictionary or an instance of QuantizationConfig, got { type (quantization )} "
1016
+ )
1017
+
989
1018
self .optimum_version = kwargs .pop ("optimum_version" , None )
990
1019
991
1020
@staticmethod
@@ -1002,3 +1031,17 @@ def dataclass_to_dict(config) -> dict:
1002
1031
v = [elem .name if isinstance (elem , Enum ) else elem for elem in v ]
1003
1032
new_config [k ] = v
1004
1033
return new_config
1034
+
1035
+ def to_dict (self ) -> Dict [str , Any ]:
1036
+ dict_config = {
1037
+ "opset" : self .opset ,
1038
+ "use_external_data_format" : self .use_external_data_format ,
1039
+ "one_external_file" : self .one_external_file ,
1040
+ "optimization" : self .dataclass_to_dict (self .optimization ),
1041
+ "quantization" : self .dataclass_to_dict (self .quantization ),
1042
+ }
1043
+
1044
+ if self .optimum_version :
1045
+ dict_config ["optimum_version" ] = self .optimum_version
1046
+
1047
+ return dict_config
0 commit comments