11
11
12
12
from copy import deepcopy
13
13
from enum import Enum
14
- from typing import Any , Dict , List , Optional , Union
14
+ from typing import Any , Dict , List , Optional
15
15
16
16
import nncf
17
17
from nncf .common .graph import NNCFNode
24
24
25
25
26
26
@api ()
27
- class QuantizationScheme :
27
+ class QuantizationScheme ( StrEnum ) :
28
28
"""
29
29
Basic enumeration for quantization scheme specification.
30
30
@@ -45,7 +45,7 @@ class QuantizerConfig:
45
45
def __init__ (
46
46
self ,
47
47
num_bits : int = QUANTIZATION_BITS ,
48
- mode : Union [ QuantizationScheme , str ] = QuantizationScheme .SYMMETRIC , # TODO(AlexanderDokuchaev): use enum
48
+ mode : QuantizationScheme = QuantizationScheme .SYMMETRIC ,
49
49
signedness_to_force : Optional [bool ] = None ,
50
50
per_channel : bool = QUANTIZATION_PER_CHANNEL ,
51
51
):
@@ -62,18 +62,20 @@ def __init__(
62
62
self .signedness_to_force = signedness_to_force
63
63
self .per_channel = per_channel
64
64
65
- def __eq__ (self , other ):
65
+ def __eq__ (self , other : object ) -> bool :
66
+ if not isinstance (other , QuantizerConfig ):
67
+ return False
66
68
return self .__dict__ == other .__dict__
67
69
68
- def __str__ (self ):
70
+ def __str__ (self ) -> str :
69
71
return "B:{bits} M:{mode} SGN:{signedness} PC:{per_channel}" .format (
70
72
bits = self .num_bits ,
71
73
mode = "S" if self .mode == QuantizationScheme .SYMMETRIC else "A" ,
72
74
signedness = "ANY" if self .signedness_to_force is None else ("S" if self .signedness_to_force else "U" ),
73
75
per_channel = "Y" if self .per_channel else "N" ,
74
76
)
75
77
76
- def __hash__ (self ):
78
+ def __hash__ (self ) -> int :
77
79
return hash (str (self ))
78
80
79
81
def is_valid_requantization_for (self , other : "QuantizerConfig" ) -> bool :
@@ -96,7 +98,7 @@ def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool:
96
98
return False
97
99
return True
98
100
99
- def compatible_with_a_unified_scale_linked_qconfig (self , linked_qconfig : "QuantizerConfig" ):
101
+ def compatible_with_a_unified_scale_linked_qconfig (self , linked_qconfig : "QuantizerConfig" ) -> bool :
100
102
"""
101
103
For two configs to be compatible in a unified scale scenario, all of their fundamental parameters
102
104
must be aligned.
@@ -155,7 +157,12 @@ class QuantizerSpec:
155
157
"""
156
158
157
159
def __init__ (
158
- self , num_bits : int , mode : QuantizationScheme , signedness_to_force : bool , narrow_range : bool , half_range : bool
160
+ self ,
161
+ num_bits : int ,
162
+ mode : QuantizationScheme ,
163
+ signedness_to_force : Optional [bool ],
164
+ narrow_range : Optional [bool ],
165
+ half_range : bool ,
159
166
):
160
167
"""
161
168
:param num_bits: Bitwidth of the quantization.
@@ -174,7 +181,9 @@ def __init__(
174
181
self .narrow_range = narrow_range
175
182
self .half_range = half_range
176
183
177
- def __eq__ (self , other : "QuantizerSpec" ):
184
+ def __eq__ (self , other : object ) -> bool :
185
+ if not isinstance (other , QuantizerSpec ):
186
+ return False
178
187
return self .__dict__ == other .__dict__
179
188
180
189
@classmethod
@@ -185,7 +194,7 @@ def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: b
185
194
class QuantizationConstraints :
186
195
REF_QCONF_OBJ = QuantizerConfig ()
187
196
188
- def __init__ (self , ** kwargs ) :
197
+ def __init__ (self , ** kwargs : Any ) -> None :
189
198
"""
190
199
Use attribute names of QuantizerConfig as arguments
191
200
to set up constraints.
@@ -220,7 +229,7 @@ def get_updated_constraints(self, overriding_constraints: "QuantizationConstrain
220
229
return QuantizationConstraints (** new_dict )
221
230
222
231
@classmethod
223
- def from_config_dict (cls , config_dict : Dict ) -> "QuantizationConstraints" :
232
+ def from_config_dict (cls , config_dict : Dict [ str , Any ] ) -> "QuantizationConstraints" :
224
233
return cls (
225
234
num_bits = config_dict .get ("bits" ),
226
235
mode = config_dict .get ("mode" ),
@@ -264,19 +273,21 @@ class QuantizerId:
264
273
structure.
265
274
"""
266
275
267
- def get_base (self ):
276
+ def get_base (self ) -> str :
268
277
raise NotImplementedError
269
278
270
279
def get_suffix (self ) -> str :
271
280
raise NotImplementedError
272
281
273
- def __str__ (self ):
282
+ def __str__ (self ) -> str :
274
283
return str (self .get_base ()) + self .get_suffix ()
275
284
276
- def __hash__ (self ):
285
+ def __hash__ (self ) -> int :
277
286
return hash ((self .get_base (), self .get_suffix ()))
278
287
279
- def __eq__ (self , other : "QuantizerId" ):
288
+ def __eq__ (self , other : object ) -> bool :
289
+ if not isinstance (other , QuantizerId ):
290
+ return False
280
291
return (self .get_base () == other .get_base ()) and (self .get_suffix () == other .get_suffix ())
281
292
282
293
@@ -299,7 +310,7 @@ class NonWeightQuantizerId(QuantizerId):
299
310
ordinary activation, function and input
300
311
"""
301
312
302
- def __init__ (self , target_node_name : NNCFNodeName , input_port_id = None ):
313
+ def __init__ (self , target_node_name : NNCFNodeName , input_port_id : Optional [ int ] = None ):
303
314
self .target_node_name = target_node_name
304
315
self .input_port_id = input_port_id
305
316
@@ -335,7 +346,7 @@ class QuantizationPreset(StrEnum):
335
346
PERFORMANCE = "performance"
336
347
MIXED = "mixed"
337
348
338
- def get_params_configured_by_preset (self , quant_group : QuantizerGroup ) -> Dict :
349
+ def get_params_configured_by_preset (self , quant_group : QuantizerGroup ) -> Dict [ str , str ] :
339
350
if quant_group == QuantizerGroup .ACTIVATIONS and self == QuantizationPreset .MIXED :
340
351
return {"mode" : QuantizationScheme .ASYMMETRIC }
341
352
return {"mode" : QuantizationScheme .SYMMETRIC }
0 commit comments