8
8
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
+ from __future__ import annotations
11
12
12
- from typing import Dict , List , Optional , Tuple , Union
13
+ from typing import Any , Dict , List , Optional , Tuple , Union
13
14
14
15
from nncf .common .graph .utils import get_reduction_axes
15
16
from nncf .common .initialization .dataloader import NNCFDataLoader
@@ -26,7 +27,12 @@ class RangeInitConfig:
26
27
parameters.
27
28
"""
28
29
29
- def __init__ (self , init_type : str , num_init_samples : int , init_type_specific_params : Dict = None ):
30
+ def __init__ (
31
+ self ,
32
+ init_type : str ,
33
+ num_init_samples : int ,
34
+ init_type_specific_params : Optional [Dict [str , int ]] = None ,
35
+ ):
30
36
"""
31
37
Initializes the quantization range initialization parameters.
32
38
@@ -43,11 +49,11 @@ def __init__(self, init_type: str, num_init_samples: int, init_type_specific_par
43
49
if self .init_type_specific_params is None :
44
50
self .init_type_specific_params = {}
45
51
46
- def __eq__ (self , other ) :
52
+ def __eq__ (self , other : object ) -> bool :
47
53
return self .__dict__ == other .__dict__
48
54
49
55
@classmethod
50
- def from_dict (cls , dct : Dict ) -> " RangeInitConfig" :
56
+ def from_dict (cls , dct : Dict [ str , Any ] ) -> RangeInitConfig :
51
57
num_init_samples = dct .get ("num_init_samples" , NUM_INIT_SAMPLES )
52
58
if num_init_samples < 0 :
53
59
raise ValueError ("Number of initialization samples must be >= 0" )
@@ -94,10 +100,10 @@ def __init__(
94
100
self .target_group = target_quantizer_group
95
101
96
102
@classmethod
97
- def from_dict (cls , dct : Dict ) -> " PerLayerRangeInitConfig" :
103
+ def from_dict (cls , dct : Dict [ str , Any ] ) -> PerLayerRangeInitConfig :
98
104
base_config = RangeInitConfig .from_dict (dct )
99
105
100
- def get_list (dct : Dict , attr_name : str ) -> Optional [List [str ]]:
106
+ def get_list (dct : Dict [ str , Any ] , attr_name : str ) -> Optional [List [str ]]:
101
107
str_or_list = dct .get (attr_name )
102
108
if str_or_list is None :
103
109
return None
@@ -185,7 +191,7 @@ def is_per_channel(self) -> bool:
185
191
"""
186
192
return self ._is_per_channel
187
193
188
- def use_per_sample_stats (self , per_sample_stats ) -> bool :
194
+ def use_per_sample_stats (self , per_sample_stats : bool ) -> bool :
189
195
"""
190
196
For activations, if per_sample_stats is True, statistics will be collected per-sample.
191
197
For weights statistics are always collected per-batch.
@@ -213,7 +219,7 @@ def _get_reduction_axes(
213
219
shape_to_reduce : Union [Tuple [int , ...], List [int ]],
214
220
quantization_axes : Union [Tuple [int , ...], List [int ]],
215
221
aggregation_axes : Union [Tuple [int , ...], List [int ]],
216
- ):
222
+ ) -> Tuple [ int , ...] :
217
223
"""
218
224
Returns axes for a reducer regarding aggregation axes. As aggregator takes axes counting from stacked tensors,
219
225
from these axes only tensor related axes should be used for reducer.
@@ -225,7 +231,7 @@ def _get_reduction_axes(
225
231
"""
226
232
axes_to_keep = set (el - 1 for el in aggregation_axes if el != 0 )
227
233
axes_to_keep .update (quantization_axes )
228
- return get_reduction_axes (axes_to_keep , shape_to_reduce )
234
+ return get_reduction_axes (list ( axes_to_keep ) , shape_to_reduce )
229
235
230
236
def _get_aggregation_axes (self , batchwise_statistics : bool ) -> Tuple [int , ...]:
231
237
"""
0 commit comments