11
11
limitations under the License.
12
12
"""
13
13
from enum import Enum
14
- from typing import Dict , List , Tuple , Optional
14
+ from typing import Dict , List , Tuple , Optional , Any
15
15
16
16
import numpy as np
17
17
import torch
24
24
from nncf .torch .checkpoint_loading import OPTIONAL_PARAMETERS_REGISTRY
25
25
from nncf .common .utils .debug import is_debug
26
26
from nncf .torch .functions import clamp
27
+ from nncf .common .graph import NNCFNodeName
27
28
from nncf .common .utils .logger import logger as nncf_logger
28
29
from nncf .common .quantization .structs import QuantizationMode , QuantizerConfig , QuantizerSpec
29
30
from nncf .common .quantization .quantizers import calculate_symmetric_level_ranges
30
31
from nncf .common .quantization .quantizers import calculate_asymmetric_level_ranges
32
+ from nncf .common .quantization .quantizer_setup import QuantizerSetupBase
33
+ from nncf .common .quantization .quantizer_setup import QuantizationPointId
34
+ from nncf .torch .graph .transformations .commands import TargetType
35
+ from nncf .torch .graph .transformations .commands import PTTargetPoint
31
36
from nncf .torch .quantization .quantize_functions import symmetric_quantize , asymmetric_quantize , \
32
37
ExportQuantizeToFakeQuantize , get_scale_zp_from_input_low_input_high , ExportQuantizeToONNXQuantDequant , TuneRange
33
38
from nncf .torch .layer_utils import COMPRESSION_MODULES , CompressionParameter
@@ -52,7 +57,21 @@ def from_str(config_value: str) -> 'HWConfigType':
52
57
raise RuntimeError ("Unknown quantizer ONNX export mode string" )
53
58
54
59
60
class PTQSpecStateNames:
    """Keys of the state dictionary built by `PTQuantizerSpec.get_state()`."""

    # Parameters that PTQuantizerSpec forwards to the base quantizer spec.
    NUM_BITS = 'num_bits'
    MODE = 'mode'
    SIGNED_TO_FORCE = 'signedness_to_force'
    NARROW_RANGE = 'narrow_range'
    HALF_RANGE = 'half_range'

    # Parameters that PTQuantizerSpec stores on its own.
    SCALE_SHAPE = 'scale_shape'
    LOGARITHM_SCALE = 'logarithm_scale'
    IS_QUANTIZED_ON_EXPORT = 'is_quantized_on_export'
    COMPRESSION_LR_MULTIPLIER = 'compression_lr_multiplier'
70
+
71
+
55
72
class PTQuantizerSpec (QuantizerSpec ):
73
+ _state_names = PTQSpecStateNames
74
+
56
75
def __init__ (self , num_bits : int ,
57
76
mode : QuantizationMode ,
58
77
signedness_to_force : Optional [bool ],
@@ -70,6 +89,7 @@ def __init__(self, num_bits: int,
70
89
activation quantizers.
71
90
"""
72
91
super ().__init__ (num_bits , mode , signedness_to_force , narrow_range , half_range )
92
+ self .per_channel = scale_shape != [1 ]
73
93
self .scale_shape = scale_shape
74
94
self .logarithm_scale = logarithm_scale
75
95
self .compression_lr_multiplier = compression_lr_multiplier
@@ -90,6 +110,155 @@ def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool,
90
110
is_quantized_on_export ,
91
111
compression_lr_multiplier )
92
112
113
def __eq__(self, other):
    """
    Compares two quantizer specs attribute by attribute.

    :param other: Object to compare with.
    :return: True if `other` is a `PTQuantizerSpec` whose instance attributes are all equal to the
        attributes of this object, False if any of them differs. `NotImplemented` is returned for
        objects of any other type, so that Python falls back to its default comparison instead of
        failing with `AttributeError` on objects that have no `__dict__`.
    """
    if not isinstance(other, PTQuantizerSpec):
        return NotImplemented
    return self.__dict__ == other.__dict__
115
+
116
@classmethod
def from_state(cls, state: Dict[str, Any]) -> 'PTQuantizerSpec':
    """
    Creates the object from its state.

    :param state: Output of `get_state()` method.
    :return: Quantizer specification built from the given state.
    """
    names = cls._state_names
    state_keys = (
        names.NUM_BITS,
        names.MODE,
        names.SIGNED_TO_FORCE,
        names.NARROW_RANGE,
        names.HALF_RANGE,
        names.SCALE_SHAPE,
        names.LOGARITHM_SCALE,
        names.IS_QUANTIZED_ON_EXPORT,
        names.COMPRESSION_LR_MULTIPLIER,
    )
    # The state keys double as the keyword argument names of the constructor, so the
    # corresponding entries of the state can be forwarded to it unchanged.
    kwargs = {key: state[key] for key in state_keys}
    return cls(**kwargs)
135
+
136
def get_state(self):
    """
    Returns a dictionary that represents the state of the object.

    :return: Dictionary that maps every name from `_state_names` to the value of the
        corresponding attribute of this object.
    """
    names = self._state_names
    state = {}
    state[names.NUM_BITS] = self.num_bits
    state[names.MODE] = self.mode
    state[names.SIGNED_TO_FORCE] = self.signedness_to_force
    state[names.NARROW_RANGE] = self.narrow_range
    state[names.HALF_RANGE] = self.half_range
    state[names.SCALE_SHAPE] = self.scale_shape
    state[names.LOGARITHM_SCALE] = self.logarithm_scale
    state[names.IS_QUANTIZED_ON_EXPORT] = self.is_quantized_on_export
    state[names.COMPRESSION_LR_MULTIPLIER] = self.compression_lr_multiplier
    return state
146
+
147
+
148
class PTQPointStateNames:
    """Keys of the state dictionary built by `PTQuantizationPoint.get_state()`."""

    TARGET_POINT = 'target_point'
    QSPEC = 'qspec'
    NAMES_OF_QUANTIZED_OPS = 'directly_quantized_operator_node_names'
152
+
153
+
154
class PTQuantizationPoint:
    """
    Stores a quantizer specification together with the target point it is attached to and the
    names of the operator nodes listed as directly quantized.
    """
    _state_names = PTQPointStateNames

    def __init__(self, qspec: PTQuantizerSpec, target_point: PTTargetPoint,
                 directly_quantized_operator_node_names: List[NNCFNodeName]):
        """
        :param qspec: Quantizer specification of the point.
        :param target_point: Target point of the quantization.
        :param directly_quantized_operator_node_names: Names of the directly quantized operator nodes.
        """
        self.qspec = qspec
        self.target_point = target_point
        self.directly_quantized_operator_node_names = directly_quantized_operator_node_names

    def is_weight_quantization_point(self) -> bool:
        """
        :return: True if the target type of the target point is `TargetType.OPERATION_WITH_WEIGHTS`.
        """
        target_type = self.target_point.target_type
        return target_type == TargetType.OPERATION_WITH_WEIGHTS

    def is_activation_quantization_point(self) -> bool:
        """
        :return: True for every point that is not a weight quantization point.
        """
        if self.is_weight_quantization_point():
            return False
        return True

    def __str__(self):
        # String forms of the target point and of the quantizer spec separated by one space.
        return ' '.join((str(self.target_point), str(self.qspec)))

    def get_state(self) -> Dict[str, Any]:
        """
        Returns a dictionary with Python data structures (dict, list, tuple, str, int, float, True, False, None) that
        represents state of the object.

        :return: state of the object
        """
        names = self._state_names
        state = {}
        state[names.TARGET_POINT] = self.target_point.get_state()
        state[names.QSPEC] = self.qspec.get_state()
        state[names.NAMES_OF_QUANTIZED_OPS] = self.directly_quantized_operator_node_names
        return state

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> 'PTQuantizationPoint':
        """
        Creates the object from its state.

        :param state: Output of `get_state()` method.
        """
        names = cls._state_names
        # The state keys double as the keyword argument names of the constructor.
        kwargs = {}
        kwargs[names.TARGET_POINT] = PTTargetPoint.from_state(state[names.TARGET_POINT])
        kwargs[names.QSPEC] = PTQuantizerSpec.from_state(state[names.QSPEC])
        kwargs[names.NAMES_OF_QUANTIZED_OPS] = state[names.NAMES_OF_QUANTIZED_OPS]
        return cls(**kwargs)
198
+
199
+
200
class PTQSetupStateNames:
    """Keys of the state dictionary built by `PTQuantizerSetup.get_state()`."""

    QUANTIZATION_POINTS = 'quantization_points'
    UNIFIED_SCALE_GROUPS = 'unified_scale_groups'
    SHARED_INPUT_OPERATION_SET_GROUPS = 'shared_input_operation_set_groups'
204
+
205
+
206
class PTQuantizerSetup(QuantizerSetupBase):
    """
    Quantizer setup that keeps quantization points by their ids together with the groups of
    quantization point ids: unified scale groups and shared input operation set groups.
    """
    _state_names = PTQSetupStateNames

    def __init__(self, unified_scale_groups, shared_input_operation_set_groups):
        """
        :param unified_scale_groups: Mapping from a group index to a set of quantization point ids.
        :param shared_input_operation_set_groups: Mapping from a group index to a set of quantization
            point ids.
        """
        super().__init__()
        self.unified_scale_groups = unified_scale_groups
        self.shared_input_operation_set_groups = shared_input_operation_set_groups
        self.quantization_points = {}  # type: Dict[QuantizationPointId, PTQuantizationPoint]

    @classmethod
    def from_state(cls, state: Dict) -> 'PTQuantizerSetup':
        """
        Creates the object from its state.

        :param state: Output of `get_state()` method.
        :return: Setup of the class this method is called on, restored from the given state.
        """
        names = cls._state_names

        def decode_groups(groups_state):
            # Keys are cast back to int and id lists back to sets, which reverts the conversion done
            # by `get_state()`. NOTE(review): keys presumably arrive as strings after a JSON
            # round trip - confirm against the callers.
            return {int(str_idx): set(qp_id_list) for str_idx, qp_id_list in groups_state.items()}

        unified_scale_groups = decode_groups(state[names.UNIFIED_SCALE_GROUPS])
        shared_input_operation_set_groups = decode_groups(state[names.SHARED_INPUT_OPERATION_SET_GROUPS])
        # `cls` is used instead of a hard-coded class name, so that subclasses restore themselves.
        setup = cls(unified_scale_groups, shared_input_operation_set_groups)
        setup.quantization_points = {
            int(str_qp_id): PTQuantizationPoint.from_state(qp_state)
            for str_qp_id, qp_state in state[names.QUANTIZATION_POINTS].items()
        }
        return setup

    def get_state(self) -> Dict[str, Any]:
        """
        Returns a dictionary with Python data structures (dict, list, tuple, str, int, float, True, False, None) that
        represents state of the object.

        :return: state of the object
        """

        def encode_groups(groups):
            # Sets of quantization point ids are stored as lists in the state.
            return {idx: list(qp_id_set) for idx, qp_id_set in groups.items()}

        quantization_points_state = {qp_id: qp.get_state() for qp_id, qp in self.quantization_points.items()}
        return {
            self._state_names.QUANTIZATION_POINTS: quantization_points_state,
            self._state_names.UNIFIED_SCALE_GROUPS: encode_groups(self.unified_scale_groups),
            self._state_names.SHARED_INPUT_OPERATION_SET_GROUPS:
                encode_groups(self.shared_input_operation_set_groups),
        }

    def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationPoint):
        """
        Registers the quantization point under the given id, replacing a point that was
        previously stored with the same id.

        :param qp_id: Id of the quantization point.
        :param qp: Quantization point to store.
        """
        self.quantization_points[qp_id] = qp
261
+
93
262
94
263
class BaseQuantizer (nn .Module ):
95
264
# pylint:disable=too-many-public-methods
@@ -636,7 +805,6 @@ def get_quantizer_config(self) -> QuantizerConfig:
636
805
per_channel = self .per_channel )
637
806
638
807
639
-
640
808
def get_per_channel_scale_shape (input_shape , is_weights , channel_idx : int = None ):
641
809
scale_shape = [1 for _ in input_shape ]
642
810
if channel_idx is None :
0 commit comments