Skip to content

Commit 0b80812

Browse files
[mypy] nncf/common/quantization (part 1) (#3192)
### Changes 1. Enable mypy for: - nncf/common/quantization/quantizers.py - nncf/common/quantization/statistics.py - nncf/common/quantization/structs.py - nncf/common/quantization/quantizer_propagation/structs.py - nncf/common/quantization/quantizer_propagation/visualizer.py 2. Inheritance of QuantizationScheme from StrEnum #2629 3. Add QuantizationScheme to Unpickler, to pass pt nightly ### Tests nightly/job/torch_nightly/438/
1 parent 98f4060 commit 0b80812

File tree

9 files changed

+56
-43
lines changed

9 files changed

+56
-43
lines changed

examples/torch/common/restricted_pickle_module.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Unpickler(pickle.Unpickler):
3838
"torch.nn": {"Module"},
3939
"torch.optim.adam": {"Adam"},
4040
"nncf.api.compression": {"CompressionStage", "CompressionLevel"},
41+
"nncf.common.quantization.structs": {"QuantizationScheme"},
4142
"numpy.core.multiarray": {"scalar"}, # numpy<2
4243
"numpy._core.multiarray": {"scalar"}, # numpy>=2
4344
"numpy": {"dtype"},

nncf/common/hardware/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def from_json(cls: type["HWConfig"], path: str) -> List[Dict[str, Any]]:
136136
return cls.from_dict(json_config)
137137

138138
@staticmethod
139-
def get_quantization_mode_from_config_value(str_val: str) -> str:
139+
def get_quantization_mode_from_config_value(str_val: str) -> QuantizationMode:
140140
if str_val == "symmetric":
141141
return QuantizationMode.SYMMETRIC
142142
if str_val == "asymmetric":

nncf/common/quantization/quantizer_propagation/structs.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
from __future__ import annotations
13+
1214
from enum import Enum
1315
from typing import List, Optional, Set, Tuple
1416

@@ -67,21 +69,23 @@ def __init__(
6769
this quantizer won't require unified scales.
6870
"""
6971
self.potential_quant_configs: List[QuantizerConfig] = quant_configs
70-
self.affected_edges = set()
72+
self.affected_edges: Set[Tuple[str, str]] = set()
7173
self.affected_ip_nodes: Set[str] = set()
7274
self.propagation_path: PropagationPath = []
7375
self.current_location_node_key = init_location_node_key
74-
self.last_accepting_location_node_key = None
76+
self.last_accepting_location_node_key: Optional[str] = None
7577
self.id = id_
7678
self.unified_scale_type = unified_scale_type
77-
self.affected_operator_nodes = set()
78-
self.quantized_input_sink_operator_nodes = set()
79-
self.downstream_propagating_quantizers = set()
79+
self.affected_operator_nodes: Set[str] = set()
80+
self.quantized_input_sink_operator_nodes: Set[str] = set()
81+
self.downstream_propagating_quantizers: Set[PropagatingQuantizer] = set()
8082

81-
def __eq__(self, other):
83+
def __eq__(self, other: object) -> bool:
84+
if not isinstance(other, PropagatingQuantizer):
85+
return False
8286
return self.id == other.id
8387

84-
def __hash__(self):
88+
def __hash__(self) -> int:
8589
return hash(self.id)
8690

8791

@@ -95,11 +99,11 @@ class QuantizerPropagationStateGraphNodeType(Enum):
9599
class SharedAffectedOpsPropagatingQuantizerGroup:
96100
"""Combines propagating quantizers that share affected operations"""
97101

98-
def __init__(self, affecting_prop_quants: Set[PropagatingQuantizer], affected_op_node_keys: Set[str]):
102+
def __init__(self, affecting_prop_quants: Set[PropagatingQuantizer], affected_op_node_keys: Set[str]) -> None:
99103
self.affecting_prop_quants: Set[PropagatingQuantizer] = affecting_prop_quants
100104
self.affected_op_node_keys: Set[str] = affected_op_node_keys
101105

102-
def update(self, other: "SharedAffectedOpsPropagatingQuantizerGroup"):
106+
def update(self, other: SharedAffectedOpsPropagatingQuantizerGroup) -> None:
103107
self.affected_op_node_keys.update(other.affected_op_node_keys)
104108
self.affecting_prop_quants.update(other.affecting_prop_quants)
105109

nncf/common/quantization/quantizer_propagation/visualizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ class QuantizerPropagationVisualizer:
2020
An object performing visualization of the quantizer propagation algorithm's state into a chosen directory.
2121
"""
2222

23-
def __init__(self, dump_dir: str = None):
23+
def __init__(self, dump_dir: str):
2424
self.dump_dir = Path(dump_dir)
2525
if self.dump_dir.exists():
2626
shutil.rmtree(str(self.dump_dir))
2727

2828
def visualize_quantizer_propagation(
2929
self, prop_solver: QuantizerPropagationSolver, prop_graph: QuantizerPropagationStateGraph, iteration: str
30-
):
30+
) -> None:
3131
self.dump_dir.mkdir(parents=True, exist_ok=True)
3232
fname = "quant_prop_iter_{}.dot".format(iteration)
3333
prop_solver.debug_visualize(prop_graph, str(self.dump_dir / Path(fname)))

nncf/common/quantization/quantizers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,5 @@ def calculate_asymmetric_level_ranges(num_bits: int, narrow_range: bool = False)
6666
return level_low, level_high
6767

6868

69-
def get_num_levels(level_low: int, level_high: int):
69+
def get_num_levels(level_low: int, level_high: int) -> int:
7070
return level_high - level_low + 1

nncf/common/quantization/statistics.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from nncf.common.utils.helpers import create_table
1717

1818

19-
def _proportion_str(num: int, total_count: int):
19+
def _proportion_str(num: int, total_count: int) -> str:
2020
percentage = 100 * (num / max(total_count, 1))
2121
return f"{percentage:.2f} % ({num} / {total_count})"
2222

@@ -170,12 +170,12 @@ def _get_bitwidth_distribution_str(self) -> str:
170170
q_total_num = wq_total_num + aq_total_num
171171

172172
bitwidths = self.num_wq_per_bitwidth.keys() | self.num_aq_per_bitwidth.keys() # union of all bitwidths
173-
bitwidths = sorted(bitwidths, reverse=True)
173+
bitwidths_sorted = sorted(bitwidths, reverse=True)
174174

175175
# Table creation
176176
header = ["Num bits (N)", "N-bits WQs / Placed WQs", "N-bits AQs / Placed AQs", "N-bits Qs / Placed Qs"]
177177
rows = []
178-
for bitwidth in bitwidths:
178+
for bitwidth in bitwidths_sorted:
179179
wq_num = self.num_wq_per_bitwidth.get(bitwidth, 0) # for current bitwidth
180180
aq_num = self.num_aq_per_bitwidth.get(bitwidth, 0) # for current bitwidth
181181
q_num = wq_num + aq_num # for current bitwidth

nncf/common/quantization/structs.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from copy import deepcopy
1313
from enum import Enum
14-
from typing import Any, Dict, List, Optional, Union
14+
from typing import Any, Dict, List, Optional
1515

1616
import nncf
1717
from nncf.common.graph import NNCFNode
@@ -24,7 +24,7 @@
2424

2525

2626
@api()
27-
class QuantizationScheme:
27+
class QuantizationScheme(StrEnum):
2828
"""
2929
Basic enumeration for quantization scheme specification.
3030
@@ -45,7 +45,7 @@ class QuantizerConfig:
4545
def __init__(
4646
self,
4747
num_bits: int = QUANTIZATION_BITS,
48-
mode: Union[QuantizationScheme, str] = QuantizationScheme.SYMMETRIC, # TODO(AlexanderDokuchaev): use enum
48+
mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
4949
signedness_to_force: Optional[bool] = None,
5050
per_channel: bool = QUANTIZATION_PER_CHANNEL,
5151
):
@@ -62,18 +62,20 @@ def __init__(
6262
self.signedness_to_force = signedness_to_force
6363
self.per_channel = per_channel
6464

65-
def __eq__(self, other):
65+
def __eq__(self, other: object) -> bool:
66+
if not isinstance(other, QuantizerConfig):
67+
return False
6668
return self.__dict__ == other.__dict__
6769

68-
def __str__(self):
70+
def __str__(self) -> str:
6971
return "B:{bits} M:{mode} SGN:{signedness} PC:{per_channel}".format(
7072
bits=self.num_bits,
7173
mode="S" if self.mode == QuantizationScheme.SYMMETRIC else "A",
7274
signedness="ANY" if self.signedness_to_force is None else ("S" if self.signedness_to_force else "U"),
7375
per_channel="Y" if self.per_channel else "N",
7476
)
7577

76-
def __hash__(self):
78+
def __hash__(self) -> int:
7779
return hash(str(self))
7880

7981
def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool:
@@ -96,7 +98,7 @@ def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool:
9698
return False
9799
return True
98100

99-
def compatible_with_a_unified_scale_linked_qconfig(self, linked_qconfig: "QuantizerConfig"):
101+
def compatible_with_a_unified_scale_linked_qconfig(self, linked_qconfig: "QuantizerConfig") -> bool:
100102
"""
101103
For two configs to be compatible in a unified scale scenario, all of their fundamental parameters
102104
must be aligned.
@@ -155,7 +157,12 @@ class QuantizerSpec:
155157
"""
156158

157159
def __init__(
158-
self, num_bits: int, mode: QuantizationScheme, signedness_to_force: bool, narrow_range: bool, half_range: bool
160+
self,
161+
num_bits: int,
162+
mode: QuantizationScheme,
163+
signedness_to_force: Optional[bool],
164+
narrow_range: Optional[bool],
165+
half_range: bool,
159166
):
160167
"""
161168
:param num_bits: Bitwidth of the quantization.
@@ -174,7 +181,9 @@ def __init__(
174181
self.narrow_range = narrow_range
175182
self.half_range = half_range
176183

177-
def __eq__(self, other: "QuantizerSpec"):
184+
def __eq__(self, other: object) -> bool:
185+
if not isinstance(other, QuantizerSpec):
186+
return False
178187
return self.__dict__ == other.__dict__
179188

180189
@classmethod
@@ -185,7 +194,7 @@ def from_config(cls, qconfig: QuantizerConfig, narrow_range: bool, half_range: b
185194
class QuantizationConstraints:
186195
REF_QCONF_OBJ = QuantizerConfig()
187196

188-
def __init__(self, **kwargs):
197+
def __init__(self, **kwargs: Any) -> None:
189198
"""
190199
Use attribute names of QuantizerConfig as arguments
191200
to set up constraints.
@@ -220,7 +229,7 @@ def get_updated_constraints(self, overriding_constraints: "QuantizationConstrain
220229
return QuantizationConstraints(**new_dict)
221230

222231
@classmethod
223-
def from_config_dict(cls, config_dict: Dict) -> "QuantizationConstraints":
232+
def from_config_dict(cls, config_dict: Dict[str, Any]) -> "QuantizationConstraints":
224233
return cls(
225234
num_bits=config_dict.get("bits"),
226235
mode=config_dict.get("mode"),
@@ -264,19 +273,21 @@ class QuantizerId:
264273
structure.
265274
"""
266275

267-
def get_base(self):
276+
def get_base(self) -> str:
268277
raise NotImplementedError
269278

270279
def get_suffix(self) -> str:
271280
raise NotImplementedError
272281

273-
def __str__(self):
282+
def __str__(self) -> str:
274283
return str(self.get_base()) + self.get_suffix()
275284

276-
def __hash__(self):
285+
def __hash__(self) -> int:
277286
return hash((self.get_base(), self.get_suffix()))
278287

279-
def __eq__(self, other: "QuantizerId"):
288+
def __eq__(self, other: object) -> bool:
289+
if not isinstance(other, QuantizerId):
290+
return False
280291
return (self.get_base() == other.get_base()) and (self.get_suffix() == other.get_suffix())
281292

282293

@@ -299,7 +310,7 @@ class NonWeightQuantizerId(QuantizerId):
299310
ordinary activation, function and input
300311
"""
301312

302-
def __init__(self, target_node_name: NNCFNodeName, input_port_id=None):
313+
def __init__(self, target_node_name: NNCFNodeName, input_port_id: Optional[int] = None):
303314
self.target_node_name = target_node_name
304315
self.input_port_id = input_port_id
305316

@@ -335,7 +346,7 @@ class QuantizationPreset(StrEnum):
335346
PERFORMANCE = "performance"
336347
MIXED = "mixed"
337348

338-
def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> Dict:
349+
def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> Dict[str, str]:
339350
if quant_group == QuantizerGroup.ACTIVATIONS and self == QuantizationPreset.MIXED:
340351
return {"mode": QuantizationScheme.ASYMMETRIC}
341352
return {"mode": QuantizationScheme.SYMMETRIC}

nncf/common/stateful_classes_registry.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
# limitations under the License.
1111

1212
import inspect
13-
from typing import Callable, Dict
13+
from typing import Callable, Dict, TypeVar
14+
15+
TObj = TypeVar("TObj", bound=type)
1416

1517

1618
class StatefulClassesRegistry:
@@ -24,15 +26,15 @@ def __init__(self) -> None:
2426
self._name_vs_class_map: Dict[str, type] = {}
2527
self._class_vs_name_map: Dict[type, str] = {}
2628

27-
def register(self, name: str = None) -> Callable[[type], type]:
29+
def register(self, name: str = None) -> Callable[[TObj], TObj]:
2830
"""
2931
Decorator to map class with some name - specified in the argument or name of the class.
3032
3133
:param name: The registration name. By default, it's name of the class.
3234
:return: The inner function for registration.
3335
"""
3436

35-
def decorator(cls: type) -> type:
37+
def decorator(cls: TObj) -> TObj:
3638
registered_name = name if name is not None else cls.__name__
3739

3840
if registered_name in self._name_vs_class_map:
@@ -88,15 +90,15 @@ class CommonStatefulClassesRegistry:
8890
"""
8991

9092
@staticmethod
91-
def register(name: str = None) -> Callable[[type], type]:
93+
def register(name: str = None) -> Callable[[TObj], TObj]:
9294
"""
9395
Decorator to map class with some name - specified in the argument or name of the class.
9496
9597
:param name: The registration name. By default, it's name of the class.
9698
:return: The inner function for registration.
9799
"""
98100

99-
def decorator(cls: type) -> type:
101+
def decorator(cls: TObj) -> TObj:
100102
PT_STATEFUL_CLASSES.register(name)(cls)
101103
TF_STATEFUL_CLASSES.register(name)(cls)
102104
return cls

pyproject.toml

-5
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,8 @@ exclude = [
123123
"nncf/common/quantization/quantizer_propagation/graph.py",
124124
"nncf/common/quantization/quantizer_propagation/grouping.py",
125125
"nncf/common/quantization/quantizer_propagation/solver.py",
126-
"nncf/common/quantization/quantizer_propagation/structs.py",
127-
"nncf/common/quantization/quantizer_propagation/visualizer.py",
128126
"nncf/common/quantization/quantizer_removal.py",
129127
"nncf/common/quantization/quantizer_setup.py",
130-
"nncf/common/quantization/quantizers.py",
131-
"nncf/common/quantization/statistics.py",
132-
"nncf/common/quantization/structs.py",
133128
]
134129

135130
[tool.ruff]

0 commit comments

Comments
 (0)