Commit da72001

Adding saturation fix parameters to compression state for PT (openvinotoolkit#1128)
* Trying to debug autoq tests failure
* First working iteration
* New compression state loading and backward compatibility
* Rebased
* Test NotImplementedError deleted
* Little clean-up
* Compression state versioning
* Comparison operators
* NAS test fix
* Composite compression loading
* Enum PT compression state versioning
* Remove accidentally added files
* Changed version to local quantizer builder one
* Fixed pylint
* Fixes
1 parent 4b4e949 commit da72001

9 files changed (+506 -167 lines)

nncf/torch/checkpoint_loading.py (+1 -1)

@@ -18,7 +18,6 @@
 import torch

 from nncf.common.utils.logger import logger as nncf_logger
-from nncf.torch.utils import maybe_convert_legacy_names_in_model_state


 def load_state(model: torch.nn.Module, state_dict_to_load: dict, is_resume: bool = False,
@@ -45,6 +44,7 @@ def load_state(model: torch.nn.Module, state_dict_to_load: dict, is_resume: bool

     model_state_dict = model.state_dict()

+    from nncf.torch.utils import maybe_convert_legacy_names_in_model_state
     maybe_convert_legacy_names_in_model_state(state_dict_to_load)
     key_matcher = KeyMatcher(is_resume, state_dict_to_load, model_state_dict, keys_to_ignore)
     new_dict = key_matcher.run()
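The only change here moves the `maybe_convert_legacy_names_in_model_state` import from module scope into `load_state`, presumably to break an import cycle: `nncf/torch/utils.py` now reaches into `nncf.torch.quantization.algo` (see its diff below), which in turn depends on this package. A minimal sketch of the deferred-import pattern, with illustrative module names that are not from this repo:

# module_a.py (illustrative)
def load_things(payload):
    # Deferred import: module_b is resolved at call time, so importing
    # module_a no longer forces module_b (and anything module_b pulls in,
    # possibly module_a itself) to be fully initialized first.
    from module_b import convert_names
    convert_names(payload)
    return payload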

nncf/torch/quantization/algo.py (+146 -110)

Large diffs are not rendered by default.
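This diff is collapsed, but the commit message ("Compression State versioning", "Enum PT compression state versioning", "Comparison operators") and the `QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME` check added to nncf/torch/utils.py below imply that `QuantizationBuilder` now stamps a version into its saved state and compares it on load for backward compatibility. A hedged sketch of such a mechanism; apart from `QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME` itself, every name and value here is hypothetical:

from enum import IntEnum

QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME = 'version'  # assumed key value


class QuantizerBuilderStateVersion(IntEnum):  # hypothetical enum
    V1 = 1  # checkpoints predating this commit carry no version key
    V2 = 2  # quantizer setup serialized with half_range and friends


def load_builder_state(state: dict) -> dict:
    # Missing key means a pre-versioning checkpoint, i.e. V1.
    version = state.get(QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME,
                        QuantizerBuilderStateVersion.V1)
    if version < QuantizerBuilderStateVersion.V2:
        state = convert_legacy_builder_state(state)  # hypothetical converter
    return state


def convert_legacy_builder_state(state: dict) -> dict:
    # Placeholder for the backward-compatibility path.
    return state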

nncf/torch/quantization/layers.py (+170 -2)

@@ -11,7 +11,7 @@
 limitations under the License.
 """
 from enum import Enum
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Tuple, Optional, Any

 import numpy as np
 import torch
@@ -24,10 +24,15 @@
 from nncf.torch.checkpoint_loading import OPTIONAL_PARAMETERS_REGISTRY
 from nncf.common.utils.debug import is_debug
 from nncf.torch.functions import clamp
+from nncf.common.graph import NNCFNodeName
 from nncf.common.utils.logger import logger as nncf_logger
 from nncf.common.quantization.structs import QuantizationMode, QuantizerConfig, QuantizerSpec
 from nncf.common.quantization.quantizers import calculate_symmetric_level_ranges
 from nncf.common.quantization.quantizers import calculate_asymmetric_level_ranges
+from nncf.common.quantization.quantizer_setup import QuantizerSetupBase
+from nncf.common.quantization.quantizer_setup import QuantizationPointId
+from nncf.torch.graph.transformations.commands import TargetType
+from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.quantization.quantize_functions import symmetric_quantize, asymmetric_quantize, \
     ExportQuantizeToFakeQuantize, get_scale_zp_from_input_low_input_high, ExportQuantizeToONNXQuantDequant, TuneRange
 from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter
@@ -52,7 +57,21 @@ def from_str(config_value: str) -> 'HWConfigType':
         raise RuntimeError("Unknown quantizer ONNX export mode string")


+class PTQSpecStateNames:
+    NUM_BITS = 'num_bits'
+    MODE = 'mode'
+    SIGNED_TO_FORCE = 'signedness_to_force'
+    NARROW_RANGE = 'narrow_range'
+    HALF_RANGE = 'half_range'
+    SCALE_SHAPE = 'scale_shape'
+    LOGARITHM_SCALE = 'logarithm_scale'
+    IS_QUANTIZED_ON_EXPORT = 'is_quantized_on_export'
+    COMPRESSION_LR_MULTIPLIER = 'compression_lr_multiplier'
+
+
 class PTQuantizerSpec(QuantizerSpec):
+    _state_names = PTQSpecStateNames
+
     def __init__(self, num_bits: int,
                  mode: QuantizationMode,
                  signedness_to_force: Optional[bool],
@@ -70,6 +89,7 @@ def __init__(self, num_bits: int,
             activation quantizers.
         """
         super().__init__(num_bits, mode, signedness_to_force, narrow_range, half_range)
+        self.per_channel = scale_shape != [1]
         self.scale_shape = scale_shape
         self.logarithm_scale = logarithm_scale
         self.compression_lr_multiplier = compression_lr_multiplier
@@ -90,6 +110,155 @@ def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool,
                    is_quantized_on_export,
                    compression_lr_multiplier)

+    def __eq__(self, other):
+        return self.__dict__ == other.__dict__
+
+    @classmethod
+    def from_state(cls, state: Dict[str, Any]) -> 'PTQuantizerSpec':
+        """
+        Creates the object from its state.
+
+        :param state: Output of `get_state()` method.
+        """
+        kwargs = {
+            cls._state_names.NUM_BITS: state['num_bits'],
+            cls._state_names.MODE: state['mode'],
+            cls._state_names.SIGNED_TO_FORCE: state['signedness_to_force'],
+            cls._state_names.NARROW_RANGE: state['narrow_range'],
+            cls._state_names.HALF_RANGE: state['half_range'],
+            cls._state_names.SCALE_SHAPE: state['scale_shape'],
+            cls._state_names.LOGARITHM_SCALE: state['logarithm_scale'],
+            cls._state_names.IS_QUANTIZED_ON_EXPORT: state['is_quantized_on_export'],
+            cls._state_names.COMPRESSION_LR_MULTIPLIER: state['compression_lr_multiplier']
+        }
+        return cls(**kwargs)
+
+    def get_state(self):
+        return {self._state_names.NUM_BITS: self.num_bits,
+                self._state_names.MODE: self.mode,
+                self._state_names.SIGNED_TO_FORCE: self.signedness_to_force,
+                self._state_names.NARROW_RANGE: self.narrow_range,
+                self._state_names.HALF_RANGE: self.half_range,
+                self._state_names.SCALE_SHAPE: self.scale_shape,
+                self._state_names.LOGARITHM_SCALE: self.logarithm_scale,
+                self._state_names.IS_QUANTIZED_ON_EXPORT: self.is_quantized_on_export,
+                self._state_names.COMPRESSION_LR_MULTIPLIER: self.compression_lr_multiplier}
+
+
+class PTQPointStateNames:
+    QSPEC = 'qspec'
+    TARGET_POINT = 'target_point'
+    NAMES_OF_QUANTIZED_OPS = 'directly_quantized_operator_node_names'
+
+
+class PTQuantizationPoint:
+    _state_names = PTQPointStateNames
+
+    def __init__(self, qspec: PTQuantizerSpec, target_point: PTTargetPoint,
+                 directly_quantized_operator_node_names: List[NNCFNodeName]):
+        self.qspec = qspec
+        self.target_point = target_point
+        self.directly_quantized_operator_node_names = directly_quantized_operator_node_names
+
+    def is_activation_quantization_point(self) -> bool:
+        return not self.is_weight_quantization_point()
+
+    def is_weight_quantization_point(self) -> bool:
+        return self.target_point.target_type == TargetType.OPERATION_WITH_WEIGHTS
+
+    def __str__(self):
+        return str(self.target_point) + ' ' + str(self.qspec)
+
+    def get_state(self) -> Dict[str, Any]:
+        """
+        Returns a dictionary with Python data structures (dict, list, tuple, str, int, float, True, False, None)
+        that represents the state of the object.
+
+        :return: state of the object
+        """
+        return {
+            self._state_names.TARGET_POINT: self.target_point.get_state(),
+            self._state_names.QSPEC: self.qspec.get_state(),
+            self._state_names.NAMES_OF_QUANTIZED_OPS: self.directly_quantized_operator_node_names
+        }
+
+    @classmethod
+    def from_state(cls, state: Dict[str, Any]) -> 'PTQuantizationPoint':
+        """
+        Creates the object from its state.
+
+        :param state: Output of `get_state()` method.
+        """
+        kwargs = {
+            cls._state_names.TARGET_POINT: PTTargetPoint.from_state(state[cls._state_names.TARGET_POINT]),
+            cls._state_names.QSPEC: PTQuantizerSpec.from_state(state[cls._state_names.QSPEC]),
+            cls._state_names.NAMES_OF_QUANTIZED_OPS: state[cls._state_names.NAMES_OF_QUANTIZED_OPS]
+        }
+        return cls(**kwargs)
+
+
+class PTQSetupStateNames:
+    SHARED_INPUT_OPERATION_SET_GROUPS = 'shared_input_operation_set_groups'
+    UNIFIED_SCALE_GROUPS = 'unified_scale_groups'
+    QUANTIZATION_POINTS = 'quantization_points'
+
+
+class PTQuantizerSetup(QuantizerSetupBase):
+    _state_names = PTQSetupStateNames
+
+    def __init__(self, unified_scale_groups, shared_input_operation_set_groups):
+        super().__init__()
+        self.unified_scale_groups = unified_scale_groups
+        self.shared_input_operation_set_groups = shared_input_operation_set_groups
+        self.quantization_points = {}  # type: Dict[QuantizationPointId, PTQuantizationPoint]
+
+    @classmethod
+    def from_state(cls, state: Dict) -> 'PTQuantizerSetup':
+        """
+        Creates the object from its state.
+
+        :param state: Output of `get_state()` method.
+        """
+
+        def decode_qp(pair):
+            str_qp_id, qp_state = pair
+            return int(str_qp_id), PTQuantizationPoint.from_state(qp_state)
+
+        def list2set(pair):
+            str_idx, qp_id_list = pair
+            return int(str_idx), set(qp_id_list)
+
+        unified_scale_groups = dict(map(list2set, state[cls._state_names.UNIFIED_SCALE_GROUPS].items()))
+        shared_input_operation_set_groups_state = state[cls._state_names.SHARED_INPUT_OPERATION_SET_GROUPS]
+        setup = PTQuantizerSetup(unified_scale_groups, shared_input_operation_set_groups_state)
+        setup.quantization_points = dict(map(decode_qp, state[cls._state_names.QUANTIZATION_POINTS].items()))
+        setup.shared_input_operation_set_groups = dict(map(list2set, shared_input_operation_set_groups_state.items()))
+        return setup
+
+    def get_state(self):
+        """
+        Returns a dictionary with Python data structures (dict, list, tuple, str, int, float, True, False, None)
+        that represents the state of the object.
+
+        :return: state of the object
+        """
+
+        def set2list(pair):
+            i, qp_id_set = pair
+            return i, list(qp_id_set)
+
+        quantization_points_state = {qp_id: qp.get_state() for qp_id, qp in self.quantization_points.items()}
+        unified_scale_groups_state = dict(map(set2list, self.unified_scale_groups.items()))
+        shared_input_operation_set_groups_state = dict(map(set2list, self.shared_input_operation_set_groups.items()))
+        return {
+            self._state_names.QUANTIZATION_POINTS: quantization_points_state,
+            self._state_names.UNIFIED_SCALE_GROUPS: unified_scale_groups_state,
+            self._state_names.SHARED_INPUT_OPERATION_SET_GROUPS: shared_input_operation_set_groups_state,
+        }
+
+    def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationPoint):
+        self.quantization_points[qp_id] = qp
+

 class BaseQuantizer(nn.Module):
     # pylint:disable=too-many-public-methods
@@ -636,7 +805,6 @@ def get_quantizer_config(self) -> QuantizerConfig:
                                per_channel=self.per_channel)


-
 def get_per_channel_scale_shape(input_shape, is_weights, channel_idx: int = None):
     scale_shape = [1 for _ in input_shape]
     if channel_idx is None:
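Two things stand out in this file. First, the newly serialized PTQuantizerSpec fields include `half_range`, which appears to be the "saturation fix" parameter of the commit title: it now survives a checkpoint round trip instead of being lost. Second, the `decode_qp`/`list2set` helpers in `PTQuantizerSetup.from_state` exist because the state is meant to survive a JSON round trip, which stringifies integer dict keys and has no set type. A minimal sketch of that round trip with plain dicts, no NNCF objects involved:

import json

unified_scale_groups = {0: {1, 2}, 1: {3}}

# get_state() side: sets become lists; json.dumps then turns int keys into strings
serialized = json.dumps({i: sorted(qp_ids) for i, qp_ids in unified_scale_groups.items()})

# from_state() side: undo both conversions, mirroring list2set above
restored = {int(i): set(qp_ids) for i, qp_ids in json.loads(serialized).items()}
assert restored == unified_scale_groups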

nncf/torch/quantization/precision_init/adjacent_quantizers.py (+1 -1)

@@ -80,7 +80,7 @@ def parse_from_quantizer_setup(self, all_quantizations: Dict[QuantizerId, BaseQu
             resulting_tuple = (quant_id, quantizer_module)
             if qp.is_weight_quantization_point():
                 wt_quant_tuples.append(resulting_tuple)
-                weight_quantized_module_node_name = qp.insertion_point.target_node_name
+                weight_quantized_module_node_name = qp.target_point.target_node_name
                 module_scope_per_weight_qp_id[weight_quantized_module_node_name] = qp_id
             elif qp.is_activation_quantization_point():
                 act_quant_tuples.append(resulting_tuple)

nncf/torch/utils.py (+26 -22)

@@ -346,6 +346,7 @@ def maybe_convert_legacy_names_in_model_state(state_dict_to_load: Dict[str, Any]
     for old_name, new_name in LEGACY_VS_NEW_BN_MAP.items():
         rename_legacy_names_in_state_dict(state_dict_to_load, legacy_names[old_name], old_name, new_name)

+
 def maybe_convert_legacy_names_in_compress_state(compression_state: Dict[str, Any]) -> None:
     """
     Convert legacy layer names in compression state in case such names exist.
@@ -359,29 +360,32 @@ def maybe_convert_legacy_names_in_compress_state(compression_state: Dict[str, An
     if not controller_state or 'quantization' not in controller_state:
         return

-    qips = controller_state['quantization']['quantizer_setup']['quantization_points']
-
-    detected_legacy_names = {
-        'BatchNorm1d': False,
-        'BatchNorm2d': False,
-        'BatchNorm3d': False,
-        'NNCFBatchNorm': False,
-    }
+    from nncf.torch.quantization.algo import QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME
+    if not controller_state['quantization'].get(QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME):
+        qips = controller_state['quantization']['quantizer_setup']['quantization_points']
+
+        detected_legacy_names = {
+            'BatchNorm1d': False,
+            'BatchNorm2d': False,
+            'BatchNorm3d': False,
+            'NNCFBatchNorm': False,
+        }
+
+        for point in qips.values():
+            name = point['qip']['target_node_name']
+            for old_name, new_name in LEGACY_VS_NEW_BN_MAP.items():
+                if old_name in name and not new_name in name:
+                    detected_legacy_names[old_name] = True
+                    point['qip']['target_node_name'] = name.replace(old_name, new_name)
+                    break
+
+        for old_name, was_detected in detected_legacy_names.items():
+            if was_detected:
+                new_name = LEGACY_VS_NEW_BN_MAP[old_name]
+                warning_deprecated('Legacy Batch Norm layer names were detected in quantization setup target'
+                                   ' point names. All occurrences of `{}` in node names were replaced by'
+                                   ' `{}`'.format(old_name, new_name))

-    for point in qips.values():
-        name = point['qip']['target_node_name']
-        for old_name, new_name in LEGACY_VS_NEW_BN_MAP.items():
-            if old_name in name and not new_name in name:
-                detected_legacy_names[old_name] = True
-                point['qip']['target_node_name'] = name.replace(old_name, new_name)
-                break
-
-    for old_name, was_detected in detected_legacy_names.items():
-        if was_detected:
-            new_name = LEGACY_VS_NEW_BN_MAP[old_name]
-            warning_deprecated('Legacy Batch Norm layer names was detected in quantization setup target point names. '
-                               'All occurrences of `{}` in nodes names was replaced by `{}`'.format(old_name,
-                                                                                                    new_name))

 def get_model_device(model: torch.nn.Module) -> torch.device:
     try:
tests/torch/nas/test_all_elasticity.py (+2 -1)

@@ -366,7 +366,8 @@ def test_multi_elasticity_state():
     prepare_train_algo_for_resume(training_ctrl)
     compression_state = training_ctrl.get_compression_state()

-    assert compression_state == REF_COMPRESSION_STATE_FOR_TWO_CONV
+    assert compression_state['ctrl_state'] == REF_COMPRESSION_STATE_FOR_TWO_CONV['ctrl_state']
+    assert compression_state['builder_state'] == REF_COMPRESSION_STATE_FOR_TWO_CONV['builder_state']


 def test_can_restore_from_state():
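Asserting on 'ctrl_state' and 'builder_state' separately tolerates extra top-level entries that get_compression_state() may now return (presumably the reason this test broke) and points a failure at the offending sub-state. A sketch of the same check written as a loop, using only names from this diff:

for key in ('ctrl_state', 'builder_state'):
    assert compression_state[key] == REF_COMPRESSION_STATE_FOR_TWO_CONV[key], key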

tests/torch/quantization/test_hawq_precision_init.py (+2 -2)

@@ -780,10 +780,10 @@ def test_compression_ratio(desc, mocker):
     config = desc.create_config()
     register_bn_adaptation_init_args(config)
     from nncf.torch.quantization.algo import QuantizationBuilder
-    get_qsetyp_spy = mocker.spy(QuantizationBuilder, '_get_quantizer_setup')
+    get_single_config_quantizer_setup_spy = mocker.spy(QuantizationBuilder, '_get_single_config_quantizer_setup')
     model, ctrl = create_compressed_model_and_algo_for_test(ConvLinear(), config)

-    quantizer_setup = get_qsetyp_spy.spy_return
+    quantizer_setup = get_single_config_quantizer_setup_spy.spy_return
     weight_qp_id_per_activation_qp_id = ctrl.groups_of_adjacent_quantizers.weight_qp_id_per_activation_qp_id
     flops_per_module = model.get_flops_per_module()
     ratio_calculator = CompressionRatioCalculator(flops_per_module, quantizer_setup, weight_qp_id_per_activation_qp_id)
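The spy only needed retargeting after the builder's setup retrieval was renamed to `_get_single_config_quantizer_setup` (the misspelled `get_qsetyp_spy` variable goes away in the process). For readers unfamiliar with the pattern: pytest-mock's `mocker.spy` wraps the real method, letting it run while recording calls and the return value. A self-contained illustration with a stand-in class, not from this repo:

def test_spy_records_return_value(mocker):
    class Builder:  # stand-in for QuantizationBuilder
        def build_setup(self):
            return {'quantization_points': {}}

    spy = mocker.spy(Builder, 'build_setup')
    setup = Builder().build_setup()  # the real method still executes

    assert spy.call_count == 1
    assert spy.spy_return is setup  # the captured return value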
