@@ -86,7 +86,7 @@ def __init__(self, framework_specific_info):
86
86
cfg_yaml_name = "{}.yaml" .format (self .__class__ .__name__ [: - len ("Adaptor" )].lower ())
87
87
self .itex_mode = self .backend == "itex" or cfg_yaml_name == "tensorflow_itex.yaml"
88
88
89
- if self .itex_mode :
89
+ if self .itex_mode : # pragma: no cover
90
90
self ._check_itex ()
91
91
92
92
self .query_handler = TensorflowQuery (
@@ -109,7 +109,7 @@ def __init__(self, framework_specific_info):
109
109
110
110
self ._last_dequantize_ops = None
111
111
112
- def _check_itex (self ):
112
+ def _check_itex (self ): # pragma: no cover
113
113
try :
114
114
import intel_extension_for_tensorflow
115
115
except :
@@ -133,7 +133,7 @@ def _tuning_cfg_to_fw(self, tuning_cfg):
133
133
134
134
invalid_op_names = [i for i in self .quantize_config ["op_wise_config" ] if i not in dispatched_op_names ]
135
135
136
- for op_name in invalid_op_names :
136
+ for op_name in invalid_op_names : # pragma: no cover
137
137
self .quantize_config ["op_wise_config" ].pop (op_name )
138
138
139
139
for each_op_info in tuning_cfg ["op" ]:
@@ -144,7 +144,7 @@ def _tuning_cfg_to_fw(self, tuning_cfg):
144
144
self .quantize_config ["op_wise_config" ].pop (op_name )
145
145
if tuning_cfg ["op" ][each_op_info ]["activation" ]["dtype" ] == "fp32" :
146
146
fp32_ops .append (op_name )
147
- if tuning_cfg ["op" ][each_op_info ]["activation" ]["dtype" ] == "bf16" :
147
+ if tuning_cfg ["op" ][each_op_info ]["activation" ]["dtype" ] == "bf16" : # pragma: no cover
148
148
bf16_ops .append (op_name )
149
149
continue
150
150
@@ -342,7 +342,7 @@ def _dump_model_op_stats(self, model_graphdef):
342
342
res [origin_op_type ]["INT8" ] += 1
343
343
344
344
if i .op in fp32_op_list :
345
- if "T" not in i .attr and i .op != "Cast" :
345
+ if "T" not in i .attr and i .op != "Cast" : # pragma: no cover
346
346
continue
347
347
if i .op == "Cast" :
348
348
if i .attr ["DstT" ].type == dtypes .bfloat16 :
@@ -432,7 +432,7 @@ def _query_quantizable_ops(self, matched_nodes):
432
432
) and len (first_conv_or_matmul_node ) == 0 :
433
433
first_conv_or_matmul_node .append ((node_name , self .unify_op_type_mapping [node_op ]))
434
434
self .recipes_ops ["first_conv_or_matmul_quantization" ] = first_conv_or_matmul_node
435
- if exclude_first_quantizable_op and (
435
+ if exclude_first_quantizable_op and ( # pragma: no cover
436
436
self .unify_op_type_mapping [node_op ].find ("conv2d" ) != - 1
437
437
or self .unify_op_type_mapping [node_op ].find ("matmul" ) != - 1
438
438
):
@@ -493,7 +493,7 @@ def _filter_unquantizable_concat(self, matched_nodes):
493
493
concat_nodes = g .query_fusion_pattern_nodes ([["ConcatV2" ]])
494
494
for i in concat_nodes :
495
495
concat_node_name = i [0 ]
496
- if concat_node_name not in target_concat_nodes :
496
+ if concat_node_name not in target_concat_nodes : # pragma: no cover
497
497
continue
498
498
input_positive_status = []
499
499
for index in range (graph_info [concat_node_name ].node .attr ["N" ].i ):
@@ -507,7 +507,7 @@ def _filter_unquantizable_concat(self, matched_nodes):
507
507
else :
508
508
positive_input = g .has_positive_input (each_input_node .name )
509
509
input_positive_status .append (positive_input )
510
- if not any (input_positive_status ):
510
+ if not any (input_positive_status ): # pragma: no cover
511
511
matched_nodes .remove (i )
512
512
513
513
def _filter_unquantizable_concat_performance_only (self , matched_nodes ):
@@ -522,7 +522,7 @@ def _filter_unquantizable_concat_performance_only(self, matched_nodes):
522
522
concat_nodes = g .query_fusion_pattern_nodes ([["ConcatV2" ]])
523
523
for i in concat_nodes :
524
524
concat_node_name = i [0 ]
525
- if concat_node_name not in target_concat_nodes :
525
+ if concat_node_name not in target_concat_nodes : # pragma: no cover
526
526
continue
527
527
input_positive_status = []
528
528
control_flow = False
@@ -531,9 +531,9 @@ def _filter_unquantizable_concat_performance_only(self, matched_nodes):
531
531
graph_info [concat_node_name ].node .input [index ]
532
532
)
533
533
each_input_node = graph_info [each_input_name ].node
534
- if each_input_node .op in ("Switch" ):
534
+ if each_input_node .op in ("Switch" ): # pragma: no cover
535
535
control_flow = True
536
- if control_flow :
536
+ if control_flow : # pragma: no cover
537
537
matched_nodes .remove (i )
538
538
539
539
def parse_quant_config (self , quant_config , model , calib_iteration ):
@@ -588,7 +588,7 @@ def _query_fw_capability(self, model):
588
588
589
589
def check_match (patterns , input_pattern ):
590
590
for i in patterns :
591
- if input_pattern == [i for i in i .replace ("+" , " " ).strip ().split (" " ) if i ]:
591
+ if input_pattern == [i for i in i .replace ("+" , " " ).strip ().split (" " ) if i ]: # pragma: no cover
592
592
return True
593
593
return False
594
594
@@ -641,7 +641,7 @@ def quantize_input(self, model):
641
641
"""
642
642
scale = None
643
643
# quantize input only support tensorflow version > 2.1.0
644
- if version1_lt_version2 (tf .version .VERSION , "2.1.0" ):
644
+ if version1_lt_version2 (tf .version .VERSION , "2.1.0" ): # pragma: no cover
645
645
logger .warning ("Quantize input needs tensorflow 2.1.0 and newer." )
646
646
return model , scale
647
647
@@ -872,7 +872,7 @@ def precisions(self):
872
872
return self ._precisions
873
873
874
874
@precisions .setter
875
- def precisions (self , precisions ):
875
+ def precisions (self , precisions ): # pragma: no cover
876
876
"""Set precision."""
877
877
if not isinstance (precisions , list ):
878
878
precisions = [precisions ]
@@ -881,7 +881,7 @@ def precisions(self, precisions):
881
881
self ._precisions = precisions
882
882
883
883
@staticmethod
884
- def check_value (name , src , supported_type , supported_value = []):
884
+ def check_value (name , src , supported_type , supported_value = []): # pragma: no cover
885
885
"""Check if the given object is the given supported type and in the given supported value.
886
886
887
887
Example::
@@ -946,7 +946,7 @@ def _get_specified_version_cfg(self, data):
946
946
config = None
947
947
948
948
def _compare (version1 , version2 ):
949
- if parse_version (version1 ) == parse_version (version2 ):
949
+ if parse_version (version1 ) == parse_version (version2 ): # pragma: no cover
950
950
return 0
951
951
elif parse_version (version1 ) < parse_version (version2 ):
952
952
return - 1
@@ -979,7 +979,7 @@ def _compare(version1, version2):
979
979
# convention. Replacing them with dot for version comparison.
980
980
sorted_list = [i .replace ("-up" , "." ) for i in sorted_list ]
981
981
sorted_list = sorted (sorted_list , key = cmp_to_key (_compare ), reverse = True )
982
- else :
982
+ else : # pragma: no cover
983
983
assert isinstance (sorted_list , str )
984
984
sorted_list = list (sorted_list .replace ("-up" , "." ).split ())
985
985
for i in sorted_list :
@@ -1025,7 +1025,7 @@ def _one_shot_query(self):
1025
1025
def _update_cfg_with_usr_definition (self ):
1026
1026
"""Add user defined precision configuration."""
1027
1027
tensorflow_config = TensorFlowConfig ()
1028
- if tensorflow_config .precisions is not None :
1028
+ if tensorflow_config .precisions is not None : # pragma: no cover
1029
1029
self .cur_config ["precisions" ]["names" ] = "," .join (tensorflow_config .precisions )
1030
1030
1031
1031
def get_version (self ):
@@ -1288,7 +1288,7 @@ def get_fuse_patterns(self):
1288
1288
elif version1_gte_version2 (tf .version .VERSION , "2.1.0" ):
1289
1289
patterns ["int8" ] = tf_int8_pattern_list
1290
1290
patterns ["uint8" ] = tf_uint8_pattern_list
1291
- if self .itex_mode :
1291
+ if self .itex_mode : # pragma: no cover
1292
1292
patterns ["int8" ].append ("FusedBatchNormV3 + Relu" )
1293
1293
patterns ["int8" ].append ("FusedBatchNormV3 + LeakyRelu" )
1294
1294
elif version1_eq_version2 (tf .version .VERSION , "1.15.0-up3" ): # pragma: no cover
@@ -1340,23 +1340,23 @@ def get_op_types_by_precision(self, precision):
1340
1340
tf .version .VERSION , "1.15.0-up3"
1341
1341
):
1342
1342
return ["Conv2D" , "MatMul" , "ConcatV2" , "MaxPool" , "AvgPool" ]
1343
- return ["MatMul" , "ConcatV2" , "MaxPool" , "AvgPool" ]
1343
+ return ["MatMul" , "ConcatV2" , "MaxPool" , "AvgPool" ] # pragma: no cover
1344
1344
if precision == "uint8" :
1345
1345
if tf .version .VERSION in spr_base_verions :
1346
1346
return [key for key in self .cur_config ["int8" ][self .quant_mode ].keys () if "Norm" not in key ]
1347
1347
if version1_gte_version2 (tf .version .VERSION , "2.1.0" ) or version1_eq_version2 (
1348
1348
tf .version .VERSION , "1.15.0-up3"
1349
1349
):
1350
1350
return ["Conv2D" , "MatMul" , "ConcatV2" , "MaxPool" , "AvgPool" , "DepthwiseConv2dNative" ]
1351
- return ["Conv2D" , "MatMul" , "ConcatV2" , "MaxPool" , "AvgPool" ]
1351
+ return ["Conv2D" , "MatMul" , "ConcatV2" , "MaxPool" , "AvgPool" ] # pragma: no cover
1352
1352
if precision == "bf16" :
1353
1353
if tf .version .VERSION in spr_base_verions :
1354
1354
return self .cur_config [precision ]
1355
1355
if version1_gte_version2 (tf .version .VERSION , "2.1.0" ) or version1_eq_version2 (
1356
1356
tf .version .VERSION , "1.15.0-up3"
1357
1357
):
1358
1358
return self .cur_config [precision ]
1359
- return []
1359
+ return [] # pragma: no cover
1360
1360
1361
1361
def get_mixed_precision_combination (self ):
1362
1362
"""Get the valid mixed precisions.
0 commit comments