 from neural_compressor.common import logger
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
-from neural_compressor.common.utils import dump_elapsed_time
+from neural_compressor.common.utils import call_counter, dump_elapsed_time
 from neural_compressor.tensorflow.quantization import quantize_model
 from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME, StaticQuantConfig
 from neural_compressor.tensorflow.utils import BaseModel, Model, constants
@@ -36,6 +36,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:


 @dump_elapsed_time("Pass auto-tune")
+@call_counter
 def autotune(
     model: Union[str, tf.keras.Model, BaseModel],
     tune_config: TuningConfig,
@@ -52,7 +53,7 @@ def autotune(
     baseline: float = eval_func_wrapper.evaluate(model)
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
-    for trial_index, quant_config in enumerate(config_loader):
+    for trial_index, quant_config in enumerate(config_loader, 1):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
@@ -65,8 +66,14 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
-            best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
-            best_quant_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
+            best_trial_record = tuning_monitor.get_best_trial_record()
+            if best_trial_record.trial_index != trial_index:
+                logger.info("Re-quantizing with best quantization config...")
+                del q_model
+                best_quant_config: BaseConfig = best_trial_record.quant_config
+                best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
+            else:
+                best_quant_model = q_model
             break
     tuning_logger.tuning_end()
     return best_quant_model
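
For context, a minimal usage sketch of the updated autotune() entry point follows. It is not part of this commit: the import path, the eval_fn / calib_dataloader keyword names, the TuningConfig arguments, and the dummy calibration data are assumptions based on the signature shown in the diff above, not a confirmed API.

import numpy as np
import tensorflow as tf

from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.tensorflow.quantization.autotune import autotune, get_all_config_set


def eval_fn(q_model) -> float:
    # Placeholder metric: a real eval_fn would run the quantized model on a
    # validation set and return a scalar the tuning monitor can compare
    # against the FP32 baseline computed before the (now 1-indexed) trial loop.
    return 1.0


# Dummy calibration data; the input shape and batching are assumptions.
calib_dataset = tf.data.Dataset.from_tensor_slices(
    np.random.rand(32, 224, 224, 3).astype(np.float32)
).batch(8)

best_model = autotune(
    model="./fp32_saved_model",  # a SavedModel path, tf.keras.Model, or BaseModel
    tune_config=TuningConfig(config_set=get_all_config_set(), max_trials=3),
    eval_fn=eval_fn,
    calib_dataloader=calib_dataset,
    calib_iteration=4,
)

With the change above, if the best trial is not the last one executed, the model is re-quantized with the best recorded config; otherwise the already-quantized q_model from the final trial is returned directly.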