Commit 5a0374e

Enhance autotune to return the best q_model directly (#1875)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 90fb431 commit 5a0374e

7 files changed: +99, -11 lines

neural_compressor/common/base_tuning.py
+6 -2

@@ -336,13 +336,17 @@ def set_baseline(self, baseline: float):
     def get_number_of_trials(self):
         return len(self.tuning_history)
 
-    def get_best_quant_config(self) -> BaseConfig:
+    def get_best_trial_record(self) -> _TrialRecord:
         assert self.get_number_of_trials() > 0, "No trial record in tuning monitor."
         # Put the record with a higher score at the beginning
         sorted_trials_records: List[_TrialRecord] = sorted(
             self.tuning_history, key=lambda x: x.trial_result, reverse=True
         )
-        return sorted_trials_records[0].quant_config
+        return sorted_trials_records[0]
+
+    def get_best_quant_config(self) -> BaseConfig:
+        best_trial_record = self.get_best_trial_record()
+        return best_trial_record.quant_config
 
     def need_stop(self) -> bool:
         """Check if need to stop tuning. Either accuracy goal is met, max trials is reached or timeout is reached.

neural_compressor/common/utils/utility.py
+15

@@ -15,9 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import importlib
 import subprocess
 import time
+from typing import Dict
 
 import cpuinfo
 import psutil
@@ -35,6 +37,7 @@
     "LazyImport",
     "CpuInfo",
     "default_tuning_logger",
+    "call_counter",
 ]
 
 
@@ -225,3 +228,15 @@ def inner_wrapper(*args, **kwargs):
         return inner_wrapper
 
     return log_process_wrapper
+
+
+# decorator for recording number of times a function is called
+FUNC_CALL_COUNTS: Dict[str, int] = collections.defaultdict(int)
+
+
+def call_counter(func):
+    def wrapper(*args, **kwargs):
+        FUNC_CALL_COUNTS[func.__name__] += 1
+        return func(*args, **kwargs)
+
+    return wrapper
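One behavioral note on the new decorator, relevant to the tests later in this commit: FUNC_CALL_COUNTS is module-level state keyed by the wrapped function's __name__, so callers clear it before asserting on counts. A small self-contained sketch (the function name dummy is illustrative):

import neural_compressor.common.utils.utility as inc_utils

@inc_utils.call_counter
def dummy():
    pass

inc_utils.FUNC_CALL_COUNTS.clear()  # reset the shared counter state first
dummy()
dummy()
assert inc_utils.FUNC_CALL_COUNTS["dummy"] == 2  # counts are keyed by __name__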

neural_compressor/tensorflow/quantization/autotune.py
+11 -4

@@ -20,7 +20,7 @@
 from neural_compressor.common import logger
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
-from neural_compressor.common.utils import dump_elapsed_time
+from neural_compressor.common.utils import call_counter, dump_elapsed_time
 from neural_compressor.tensorflow.quantization import quantize_model
 from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME, StaticQuantConfig
 from neural_compressor.tensorflow.utils import BaseModel, Model, constants
@@ -36,6 +36,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
 
 
 @dump_elapsed_time("Pass auto-tune")
+@call_counter
 def autotune(
     model: Union[str, tf.keras.Model, BaseModel],
     tune_config: TuningConfig,
@@ -52,7 +53,7 @@ def autotune(
     baseline: float = eval_func_wrapper.evaluate(model)
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
-    for trial_index, quant_config in enumerate(config_loader):
+    for trial_index, quant_config in enumerate(config_loader, 1):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
@@ -65,8 +66,14 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
-            best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
-            best_quant_model = quantize_model(model, quant_config, calib_dataloader, calib_iteration)
+            best_trial_record = tuning_monitor.get_best_trial_record()
+            if best_trial_record.trial_index != trial_index:
+                logger.info("Re-quantizing with best quantization config...")
+                del q_model
+                best_quant_config: BaseConfig = best_trial_record.quant_config
+                best_quant_model = quantize_model(model, best_quant_config, calib_dataloader, calib_iteration)
+            else:
+                best_quant_model = q_model
             break
     tuning_logger.tuning_end()
     return best_quant_model
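This is the core behavioral change of the commit: when the best trial is the one that just ran, its q_model is returned directly; only when an earlier trial won is the original model re-quantized with the winning config. A stand-in helper sketching that decision (the name pick_best_model and the quantize_fn parameter are illustrative, not part of the commit):

def pick_best_model(model, q_model, trial_index, tuning_monitor, quantize_fn):
    # Reuse the in-memory q_model when the last trial is also the best one;
    # otherwise free it and re-quantize the original model with the winning config.
    best_trial_record = tuning_monitor.get_best_trial_record()
    if best_trial_record.trial_index == trial_index:
        return q_model
    del q_model
    return quantize_fn(model, best_trial_record.quant_config)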

neural_compressor/torch/quantization/autotune.py
+5 -4

@@ -72,7 +72,7 @@ def autotune(
     baseline: float = eval_func_wrapper.evaluate(model)
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
-    for trial_index, quant_config in enumerate(config_loader):
+    for trial_index, quant_config in enumerate(config_loader, 1):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.execution_start()
         logger.info(quant_config.to_dict())
@@ -93,10 +93,11 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
-            if trial_index == 0:  # recover the best q_model from previous results.
-                logger.info("Reconvering the best quantized model...")
+            best_trial_record = tuning_monitor.get_best_trial_record()
+            if best_trial_record.trial_index != trial_index:
+                logger.info("Re-quantizing with best quantization config...")
                 del q_model  # maybe gc.collect() is needed for memory release
-                best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
+                best_quant_config: BaseConfig = best_trial_record.quant_config
                 # !!! Make sure to use deepcopy only when inplace is set to `True`.
                 q_model = quantize(
                     deepcopy(model),
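The PyTorch path gets the same treatment as the TensorFlow path above. One consequence the new tests rely on: with two trials, quantize() runs once per trial, plus exactly one extra call when the best trial is not the last one. A rough arithmetic sketch under the tests' assumed accuracies (baseline 1.0):

# Illustrative expectation matching the new tests below, not part of the commit.
trials = 2
best_is_last = True  # e.g. trial accuracies [0.9, 1.1]; False for [0.9, 0.8]
quantize_calls = trials + (0 if best_is_last else 1)
assert quantize_calls == 2  # would be 3 if an earlier trial had won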

neural_compressor/torch/quantization/quantize.py
+2 -1

@@ -18,7 +18,7 @@
 import torch
 
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
-from neural_compressor.common.utils import Mode, log_process
+from neural_compressor.common.utils import Mode, call_counter, log_process
 from neural_compressor.torch.quantization.config import SmoothQuantConfig, StaticQuantConfig
 from neural_compressor.torch.utils import is_ipex_available, logger
 from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info
@@ -31,6 +31,7 @@ def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_nam
 
 
 @log_process(mode=Mode.QUANTIZE)
+@call_counter
 def quantize(
     model: torch.nn.Module,
     quant_config: BaseConfig,

test/3x/common/test_utility.py
+22

@@ -11,6 +11,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
+import neural_compressor.common.utils.utility as inc_utils
 from neural_compressor.common import options
 from neural_compressor.common.utils import (
     CpuInfo,
@@ -166,5 +167,26 @@ def __init__(self):
         assert instance2.value == 1, "Singleton should return the same instance"
 
 
+class TestCallCounter(unittest.TestCase):
+    def test_call_counter(self):
+        # empty dict
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        @inc_utils.call_counter
+        def add(a, b):
+            return a + b
+
+        # Initial count should be 0
+        self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 0)
+
+        # Call the function multiple times
+        add(1, 2)
+        add(3, 4)
+        add(5, 6)
+
+        # Count should be incremented accordingly
+        self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 3)
+
+
 if __name__ == "__main__":
     unittest.main()

test/3x/torch/test_autotune.py
+38

@@ -6,6 +6,7 @@
 import torch
 import transformers
 
+import neural_compressor.common.utils.utility as inc_utils
 from neural_compressor.common import logger
 from neural_compressor.torch.quantization import (
     MixPrecisionConfig,
@@ -163,6 +164,43 @@ def eval_acc_fn(model) -> float:
 
         custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        print(inc_utils.FUNC_CALL_COUNTS)
+        self.assertIsNotNone(best_model)
+
+    def test_autotune_return_qmodel_directly(self):
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        baseline = 1
+        eval_result = [0.9, 1.1]
+        acc_list = [baseline] + eval_result
+
+        def eval_acc_fn(model) -> float:
+            acc = acc_list.pop(0)
+            return acc
+
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        assert (
+            inc_utils.FUNC_CALL_COUNTS.get("quantize") == 2
+        ), f"quantize should be called twice, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
+        self.assertIsNotNone(best_model)
+
+    def test_autotune_return_re_quant_qmodel(self):
+        inc_utils.FUNC_CALL_COUNTS.clear()
+
+        baseline = 1
+        eval_result = [0.9, 0.8]
+        acc_list = [baseline] + eval_result
+
+        def eval_acc_fn(model) -> float:
+            acc = acc_list.pop(0)
+            return acc
+
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(bits=[4, 6])], max_trials=2)
+        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fn=eval_acc_fn)
+        assert (
+            inc_utils.FUNC_CALL_COUNTS.get("quantize") == 3
+        ), f"quantize should be called three times, but got {inc_utils.FUNC_CALL_COUNTS.get('quantize')}"
         self.assertIsNotNone(best_model)
 
     @reset_tuning_target