Skip to content

Commit 0931072

Browse files
[mypy] nncf/common/quantization (part 2) (#3197)
### Changes Enable mypy for part of files in nncf/common/quantization
1 parent d1b5229 commit 0931072

File tree

9 files changed

+208
-181
lines changed

9 files changed

+208
-181
lines changed

nncf/common/quantization/initialization/range.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
11+
from __future__ import annotations
1112

12-
from typing import Dict, List, Optional, Tuple, Union
13+
from typing import Any, Dict, List, Optional, Tuple, Union
1314

1415
from nncf.common.graph.utils import get_reduction_axes
1516
from nncf.common.initialization.dataloader import NNCFDataLoader
@@ -26,7 +27,12 @@ class RangeInitConfig:
2627
parameters.
2728
"""
2829

29-
def __init__(self, init_type: str, num_init_samples: int, init_type_specific_params: Dict = None):
30+
def __init__(
31+
self,
32+
init_type: str,
33+
num_init_samples: int,
34+
init_type_specific_params: Optional[Dict[str, int]] = None,
35+
):
3036
"""
3137
Initializes the quantization range initialization parameters.
3238
@@ -43,11 +49,11 @@ def __init__(self, init_type: str, num_init_samples: int, init_type_specific_par
4349
if self.init_type_specific_params is None:
4450
self.init_type_specific_params = {}
4551

46-
def __eq__(self, other):
52+
def __eq__(self, other: object) -> bool:
4753
return self.__dict__ == other.__dict__
4854

4955
@classmethod
50-
def from_dict(cls, dct: Dict) -> "RangeInitConfig":
56+
def from_dict(cls, dct: Dict[str, Any]) -> RangeInitConfig:
5157
num_init_samples = dct.get("num_init_samples", NUM_INIT_SAMPLES)
5258
if num_init_samples < 0:
5359
raise ValueError("Number of initialization samples must be >= 0")
@@ -94,10 +100,10 @@ def __init__(
94100
self.target_group = target_quantizer_group
95101

96102
@classmethod
97-
def from_dict(cls, dct: Dict) -> "PerLayerRangeInitConfig":
103+
def from_dict(cls, dct: Dict[str, Any]) -> PerLayerRangeInitConfig:
98104
base_config = RangeInitConfig.from_dict(dct)
99105

100-
def get_list(dct: Dict, attr_name: str) -> Optional[List[str]]:
106+
def get_list(dct: Dict[str, Any], attr_name: str) -> Optional[List[str]]:
101107
str_or_list = dct.get(attr_name)
102108
if str_or_list is None:
103109
return None
@@ -185,7 +191,7 @@ def is_per_channel(self) -> bool:
185191
"""
186192
return self._is_per_channel
187193

188-
def use_per_sample_stats(self, per_sample_stats) -> bool:
194+
def use_per_sample_stats(self, per_sample_stats: bool) -> bool:
189195
"""
190196
For activations, if per_sample_stats is True, statistics will be collected per-sample.
191197
For weights statistics are always collected per-batch.
@@ -213,7 +219,7 @@ def _get_reduction_axes(
213219
shape_to_reduce: Union[Tuple[int, ...], List[int]],
214220
quantization_axes: Union[Tuple[int, ...], List[int]],
215221
aggregation_axes: Union[Tuple[int, ...], List[int]],
216-
):
222+
) -> Tuple[int, ...]:
217223
"""
218224
Returns axes for a reducer regarding aggregation axes. As aggregator takes axes counting from stacked tensors,
219225
from these axes only tensor related axes should be used for reducer.
@@ -225,7 +231,7 @@ def _get_reduction_axes(
225231
"""
226232
axes_to_keep = set(el - 1 for el in aggregation_axes if el != 0)
227233
axes_to_keep.update(quantization_axes)
228-
return get_reduction_axes(axes_to_keep, shape_to_reduce)
234+
return get_reduction_axes(list(axes_to_keep), shape_to_reduce)
229235

230236
def _get_aggregation_axes(self, batchwise_statistics: bool) -> Tuple[int, ...]:
231237
"""

nncf/common/quantization/quantizer_propagation/graph.py

+83-63
Large diffs are not rendered by default.

nncf/common/quantization/quantizer_propagation/grouping.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class UnifiedScalePropagatingQuantizerGroupManager:
2121
quantized model.
2222
"""
2323

24-
def __init__(self):
24+
def __init__(self) -> None:
2525
self._next_gid = 0
2626
self._group_vs_prop_quants_dict: Dict[int, Set[PropagatingQuantizer]] = {}
2727

@@ -46,7 +46,7 @@ def register_group(self, prop_quants: Set[PropagatingQuantizer]) -> int:
4646
self._group_vs_prop_quants_dict[gid] = prop_quants
4747
return gid
4848

49-
def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer):
49+
def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer) -> None:
5050
"""
5151
Adds a propagating quantizer to an already existing group.
5252
@@ -62,7 +62,7 @@ def add_to_group(self, target_gid: int, prop_quant: PropagatingQuantizer):
6262
)
6363
self._group_vs_prop_quants_dict[target_gid].add(prop_quant)
6464

65-
def remove_from_group(self, group: int, prop_quant: PropagatingQuantizer):
65+
def remove_from_group(self, group: int, prop_quant: PropagatingQuantizer) -> None:
6666
"""
6767
Removes a propagating quantizer from a group.
6868
@@ -91,7 +91,7 @@ def get_group_id_by_propagating_quantizer_id(self, requested_pqid: int) -> Optio
9191
return gid
9292
return None
9393

94-
def merge_groups(self, merge_to_gid: int, merge_from_gid: int):
94+
def merge_groups(self, merge_to_gid: int, merge_from_gid: int) -> None:
9595
"""
9696
Merges two groups into a single one. The `merge_to_gid` group retains its group ID.
9797
@@ -110,11 +110,11 @@ class QuantizersWaitingForMergeManager:
110110
and corresponding node keys.
111111
"""
112112

113-
def __init__(self):
113+
def __init__(self) -> None:
114114
self._branching_node_keys_vs_quantizers_waiting_for_merge: Dict[str, Set[PropagatingQuantizer]] = {}
115115
self._quantizers_vs_branching_node_keys: Dict[PropagatingQuantizer, str] = {}
116116

117-
def add_propagating_quantizer_to_wait_on_node_key(self, pq: PropagatingQuantizer, branching_node_key: str):
117+
def add_propagating_quantizer_to_wait_on_node_key(self, pq: PropagatingQuantizer, branching_node_key: str) -> None:
118118
"""
119119
Registers a propagating quantizer as "waiting" on a node in QuantizerPropagationStateGraph.
120120
@@ -146,10 +146,10 @@ def get_waiting_quantizers_for_branching_node_key(self, node_key: str) -> Set[Pr
146146
"""
147147
return self._branching_node_keys_vs_quantizers_waiting_for_merge[node_key]
148148

149-
def __contains__(self, item: PropagatingQuantizer):
149+
def __contains__(self, item: PropagatingQuantizer) -> bool:
150150
return item in self._quantizers_vs_branching_node_keys
151151

152-
def resolve_merged_node(self, branching_node_key: str):
152+
def resolve_merged_node(self, branching_node_key: str) -> None:
153153
"""
154154
De-registers any quantizers that were previously registered to be "waiting" on a given node key.
155155
:param branching_node_key: The node key in QuantizerPropagationStateGraph that some propagating

0 commit comments

Comments
 (0)