Skip to content

Commit 190006d

Browse files
[mypy] nncf/common/pruning (#3198)
### Changes Enable mypy for nncf/common/pruning
1 parent 0931072 commit 190006d

13 files changed

+181
-149
lines changed

nncf/common/pruning/clusterization.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,26 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Callable, Dict, Generic, Hashable, List, TypeVar
12+
from typing import Callable, Dict, Generic, Hashable, List, Optional, TypeVar
1313

1414
T = TypeVar("T")
1515

1616

1717
class Cluster(Generic[T]):
1818
"""
19-
Represents element of Сlusterization. Groups together elements.
19+
Represents element of Clusterization. Groups together elements.
2020
"""
2121

22-
def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]):
22+
def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]) -> None:
2323
self.id = cluster_id
2424
self.elements = list(elements)
2525
self.importance = max(nodes_orders)
2626

27-
def clean_cluster(self):
27+
def clean_cluster(self) -> None:
2828
self.elements = []
2929
self.importance = 0
3030

31-
def add_elements(self, elements: List[T], importance: int):
31+
def add_elements(self, elements: List[T], importance: int) -> None:
3232
self.elements.extend(elements)
3333
self.importance = max(self.importance, importance)
3434

@@ -39,7 +39,7 @@ class Clusterization(Generic[T]):
3939
delete existing one or merge existing clusters.
4040
"""
4141

42-
def __init__(self, id_fn: Callable[[T], Hashable] = None):
42+
def __init__(self, id_fn: Optional[Callable[[T], Hashable]] = None) -> None:
4343
self.clusters: Dict[int, Cluster[T]] = {}
4444
self._element_to_cluster: Dict[Hashable, int] = {}
4545
if id_fn is None:
@@ -78,7 +78,7 @@ def is_node_in_clusterization(self, node_id: int) -> bool:
7878
"""
7979
return node_id in self._element_to_cluster
8080

81-
def add_cluster(self, cluster: Cluster[T]):
81+
def add_cluster(self, cluster: Cluster[T]) -> None:
8282
"""
8383
Adds provided cluster to clusterization.
8484
@@ -89,9 +89,9 @@ def add_cluster(self, cluster: Cluster[T]):
8989
raise IndexError("Cluster with index = {} already exist".format(cluster_id))
9090
self.clusters[cluster_id] = cluster
9191
for elt in cluster.elements:
92-
self._element_to_cluster[self._id_fn(elt)] = cluster_id
92+
self._element_to_cluster[self._id_fn(elt)] = cluster_id # type: ignore[no-untyped-call]
9393

94-
def delete_cluster(self, cluster_id: int):
94+
def delete_cluster(self, cluster_id: int) -> None:
9595
"""
9696
Removes cluster with `cluster_id` from clusterization.
9797
@@ -100,7 +100,7 @@ def delete_cluster(self, cluster_id: int):
100100
if cluster_id not in self.clusters:
101101
raise IndexError("No cluster with index = {} to delete".format(cluster_id))
102102
for elt in self.clusters[cluster_id].elements:
103-
node_id = self._id_fn(elt)
103+
node_id = self._id_fn(elt) # type: ignore[no-untyped-call]
104104
self._element_to_cluster.pop(node_id)
105105
self.clusters.pop(cluster_id)
106106

@@ -123,7 +123,7 @@ def get_all_nodes(self) -> List[T]:
123123
all_elements.extend(cluster.elements)
124124
return all_elements
125125

126-
def merge_clusters(self, first_id: int, second_id: int):
126+
def merge_clusters(self, first_id: int, second_id: int) -> None:
127127
"""
128128
Merges two clusters with provided ids.
129129
@@ -135,15 +135,15 @@ def merge_clusters(self, first_id: int, second_id: int):
135135
if cluster_1.importance > cluster_2.importance:
136136
cluster_1.add_elements(cluster_2.elements, cluster_2.importance)
137137
for elt in cluster_2.elements:
138-
self._element_to_cluster[self._id_fn(elt)] = first_id
138+
self._element_to_cluster[self._id_fn(elt)] = first_id # type: ignore[no-untyped-call]
139139
self.clusters.pop(second_id)
140140
else:
141141
cluster_2.add_elements(cluster_1.elements, cluster_1.importance)
142142
for elt in cluster_1.elements:
143-
self._element_to_cluster[self._id_fn(elt)] = second_id
143+
self._element_to_cluster[self._id_fn(elt)] = second_id # type: ignore[no-untyped-call]
144144
self.clusters.pop(first_id)
145145

146-
def merge_list_of_clusters(self, clusters: List[int]):
146+
def merge_list_of_clusters(self, clusters: List[int]) -> None:
147147
"""
148148
Merges provided clusters.
149149

nncf/common/pruning/mask_propagation.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Dict, List, Optional, Type
12+
from typing import Dict, List, Optional, Set, Type
1313

1414
from nncf.common.graph import NNCFGraph
1515
from nncf.common.pruning.operations import BasePruningOp
@@ -38,7 +38,7 @@ def __init__(
3838
graph: NNCFGraph,
3939
pruning_operator_metatypes: PruningOperationsMetatypeRegistry,
4040
tensor_processor: Optional[Type[NNCFPruningBaseTensorProcessor]] = None,
41-
):
41+
) -> None:
4242
"""
4343
Initializes MaskPropagationAlgorithm.
4444
@@ -51,7 +51,7 @@ def __init__(
5151
self._pruning_operator_metatypes = pruning_operator_metatypes
5252
self._tensor_processor = tensor_processor
5353

54-
def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
54+
def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]:
5555
"""
5656
Returns class of metaop that corresponds to `type_name` type.
5757
@@ -63,14 +63,14 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
6363
cls = self._pruning_operator_metatypes.registry_dict["stop_propagation_ops"]
6464
return cls
6565

66-
def mask_propagation(self):
66+
def mask_propagation(self) -> None:
6767
"""
6868
Mask propagation in graph:
6969
to propagate masks run method mask_propagation (of metaop of current node) on all nodes in topological order.
7070
"""
7171
for node in self._graph.topological_sort():
7272
cls = self.get_meta_operation_by_type_name(node.node_type)
73-
cls.mask_propagation(node, self._graph, self._tensor_processor)
73+
cls.mask_propagation(node, self._graph, self._tensor_processor) # type: ignore
7474

7575
def symbolic_mask_propagation(
7676
self, prunable_layers_types: List[str], can_prune_after_analysis: Dict[int, PruningAnalysisDecision]
@@ -96,7 +96,7 @@ def symbolic_mask_propagation(
9696
"""
9797

9898
can_be_closing_convs = self._get_can_closing_convs(prunable_layers_types)
99-
can_prune_by_dim = {k: None for k in can_be_closing_convs}
99+
can_prune_by_dim: Dict[int, PruningAnalysisDecision] = {k: None for k in can_be_closing_convs} # type: ignore
100100
for node in self._graph.topological_sort():
101101
if node.node_id in can_be_closing_convs and can_prune_after_analysis[node.node_id]:
102102
# Set output mask
@@ -109,15 +109,16 @@ def symbolic_mask_propagation(
109109
input_masks = get_input_masks(node, self._graph)
110110
if any(input_masks):
111111
assert len(input_masks) == 1
112-
input_mask: SymbolicMask = input_masks[0]
112+
input_mask = input_masks[0]
113+
assert isinstance(input_mask, SymbolicMask)
113114

114115
for producer in input_mask.mask_producers:
115116
previously_dims_equal = (
116117
True if can_prune_by_dim[producer.id] is None else can_prune_by_dim[producer.id]
117118
)
118119

119120
is_dims_equal = get_input_channels(node) == input_mask.shape[0]
120-
decision = previously_dims_equal and is_dims_equal
121+
decision = bool(previously_dims_equal and is_dims_equal)
121122
can_prune_by_dim[producer.id] = PruningAnalysisDecision(
122123
decision, PruningAnalysisReason.DIMENSION_MISMATCH
123124
)
@@ -130,7 +131,7 @@ def symbolic_mask_propagation(
130131
can_prune_by_dim[producer.id] = PruningAnalysisDecision(False, PruningAnalysisReason.LAST_CONV)
131132
# Update decision for nodes which
132133
# have no closing convolution
133-
convs_without_closing_conv = {}
134+
convs_without_closing_conv: Dict[int, PruningAnalysisDecision] = {}
134135
for k, v in can_prune_by_dim.items():
135136
if v is None:
136137
convs_without_closing_conv[k] = PruningAnalysisDecision(
@@ -144,8 +145,8 @@ def symbolic_mask_propagation(
144145

145146
return can_prune_by_dim
146147

147-
def _get_can_closing_convs(self, prunable_layers_types) -> Dict:
148-
retval = set()
148+
def _get_can_closing_convs(self, prunable_layers_types: List[str]) -> Set[int]:
149+
retval: Set[int] = set()
149150
for node in self._graph.get_all_nodes():
150151
if node.node_type in prunable_layers_types and not (
151152
is_grouped_conv(node) or is_batched_linear(node, self._graph)

nncf/common/pruning/model_analysis.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Dict, List
12+
from typing import Dict, List, Optional, Type, cast
1313

1414
from nncf.common.graph import NNCFGraph
1515
from nncf.common.graph import NNCFNode
@@ -23,14 +23,16 @@
2323
from nncf.common.pruning.utils import is_prunable_depthwise_conv
2424

2525

26-
def get_position(nodes_list: List[NNCFNode], idx: int):
26+
def get_position(nodes_list: List[NNCFNode], idx: int) -> Optional[int]:
2727
for i, node in enumerate(nodes_list):
2828
if node.node_id == idx:
2929
return i
3030
return None
3131

3232

33-
def merge_clusters_for_nodes(nodes_to_merge: List[NNCFNode], clusterization: Clusterization):
33+
def merge_clusters_for_nodes(
34+
nodes_to_merge: List[NNCFNode], clusterization: Clusterization # type:ignore[type-arg]
35+
) -> None:
3436
"""
3537
Merges clusters to which nodes from nodes_to_merge belongs.
3638
@@ -75,7 +77,7 @@ def cluster_special_ops(
7577
# 0. Initially all nodes is a separate clusters
7678
clusterization = Clusterization[NNCFNode](lambda x: x.node_id)
7779
for i, node in enumerate(all_special_nodes):
78-
cluster = Cluster[NNCFNode](i, [node], [get_position(topologically_sorted_nodes, node.node_id)])
80+
cluster = Cluster[NNCFNode](i, [node], [get_position(topologically_sorted_nodes, node.node_id)]) # type: ignore
7981
clusterization.add_cluster(cluster)
8082

8183
for node in topologically_sorted_nodes:
@@ -125,7 +127,9 @@ def __init__(
125127
self._pruning_operator_metatypes = pruning_operator_metatypes
126128
self._prune_operations_types = prune_operations_types
127129
pruning_op_metatypes_dict = self._pruning_operator_metatypes.registry_dict
128-
self._stop_propagation_op_metatype = pruning_op_metatypes_dict["stop_propagation_ops"]
130+
self._stop_propagation_op_metatype = cast(
131+
Type[BasePruningOp], pruning_op_metatypes_dict["stop_propagation_ops"]
132+
)
129133
self._concat_op_metatype = pruning_op_metatypes_dict["concat"]
130134

131135
self.can_prune = {idx: True for idx in self.graph.get_all_node_ids()}
@@ -151,7 +155,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool:
151155
"""
152156
return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases()
153157

154-
def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
158+
def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]:
155159
"""
156160
Returns class of metaop that corresponds to `type_name` type.
157161
@@ -162,7 +166,7 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
162166
cls = self._stop_propagation_op_metatype
163167
return cls
164168

165-
def propagate_can_prune_attr_up(self):
169+
def propagate_can_prune_attr_up(self) -> None:
166170
"""
167171
Propagating can_prune attribute in reversed topological order.
168172
This attribute depends on accept_pruned_input and can_prune attributes of output nodes.
@@ -181,7 +185,7 @@ def propagate_can_prune_attr_up(self):
181185
)
182186
self.can_prune[node.node_id] = outputs_accept_pruned_input and outputs_will_be_pruned
183187

184-
def propagate_can_prune_attr_down(self):
188+
def propagate_can_prune_attr_down(self) -> None:
185189
"""
186190
Propagating can_prune attribute down to fix all branching cases with one pruned and one not pruned
187191
branches.
@@ -199,7 +203,7 @@ def propagate_can_prune_attr_down(self):
199203
):
200204
self.can_prune[node.node_id] = can_prune
201205

202-
def set_accept_pruned_input_attr(self):
206+
def set_accept_pruned_input_attr(self) -> None:
203207
for nncf_node in self.graph.get_all_nodes():
204208
cls = self.get_meta_operation_by_type_name(nncf_node.node_type)
205209
self.accept_pruned_input[nncf_node.node_id] = cls.accept_pruned_input(nncf_node)

nncf/common/pruning/node_selector.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:
102102

103103
# 2. Clusters for nodes that should be pruned together (taking into account clusters for special ops)
104104
for i, cluster in enumerate(special_ops_clusterization.get_all_clusters()):
105-
all_pruned_inputs = {}
106-
clusters_to_merge = []
105+
all_pruned_inputs: Dict[int, NNCFNode] = {}
106+
clusters_to_merge: List[int] = []
107107

108108
for node in cluster.elements:
109109
sources = get_sources_of_node(node, graph, self._prune_operations_types)
@@ -116,7 +116,7 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:
116116
all_pruned_inputs[source_node.node_id] = source_node
117117

118118
if all_pruned_inputs:
119-
cluster = Cluster[NNCFNode](i, all_pruned_inputs.values(), all_pruned_inputs.keys())
119+
cluster = Cluster[NNCFNode](i, list(all_pruned_inputs.values()), list(all_pruned_inputs.keys()))
120120
clusters_to_merge.append(cluster.id)
121121
pruned_nodes_clusterization.add_cluster(cluster)
122122

@@ -202,7 +202,7 @@ def _get_multiforward_nodes(self, graph: NNCFGraph) -> List[List[NNCFNode]]:
202202
def _pruning_dimensions_analysis(
203203
self,
204204
graph: NNCFGraph,
205-
pruned_nodes_clusterization: Clusterization,
205+
pruned_nodes_clusterization: Clusterization, # type: ignore[type-arg]
206206
can_prune_after_check: Dict[int, PruningAnalysisDecision],
207207
) -> Dict[int, PruningAnalysisDecision]:
208208
"""
@@ -251,7 +251,7 @@ def _check_all_closing_nodes_are_feasible(
251251
return can_prune_updated
252252

253253
def _check_internal_groups_dim(
254-
self, pruned_nodes_clusterization: Clusterization
254+
self, pruned_nodes_clusterization: Clusterization # type: ignore[type-arg]
255255
) -> Dict[int, PruningAnalysisDecision]:
256256
"""
257257
Checks pruning dimensions of all nodes in each cluster group are equal and
@@ -278,7 +278,7 @@ def _check_internal_groups_dim(
278278
def _should_prune_groups_analysis(
279279
self,
280280
graph: NNCFGraph,
281-
pruned_nodes_clusterization: Clusterization,
281+
pruned_nodes_clusterization: Clusterization, # type: ignore[type-arg]
282282
can_prune: Dict[int, PruningAnalysisDecision],
283283
) -> Dict[int, PruningAnalysisDecision]:
284284
"""
@@ -312,7 +312,7 @@ def _should_prune_groups_analysis(
312312
return can_prune_updated
313313

314314
def _filter_groups(
315-
self, pruned_nodes_clusterization: Clusterization, can_prune: Dict[int, PruningAnalysisDecision]
315+
self, pruned_nodes_clusterization: Clusterization, can_prune: Dict[int, PruningAnalysisDecision] # type: ignore[type-arg]
316316
) -> None:
317317
"""
318318
Check whether all nodes in group can be pruned based on user-defined constraints and

0 commit comments

Comments
 (0)