
Commit e6664b0

Authored Feb 28, 2024
add fp8 autotune ut and fix bug in autotune (#1638)
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent b4e37b7 commit e6664b0

File tree

4 files changed: +53 −6 lines changed
 

neural_compressor/torch/quantization/__init__.py (+1)

@@ -31,6 +31,7 @@
     get_default_hqq_config,
     FP8Config,
     get_default_fp8_config,
+    get_default_fp8_config_set,
 )
 
 from neural_compressor.torch.quantization.autotune import (
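With this re-export, the helper becomes importable from the package root instead of from neural_compressor.torch.quantization.config (a one-line sketch):

from neural_compressor.torch.quantization import get_default_fp8_config_set

config_set = get_default_fp8_config_set()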

neural_compressor/torch/quantization/autotune.py (+5 −2)

@@ -72,10 +72,13 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
+            del q_model  # maybe gc.collect() is needed for memory release
             best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
             # !!! Make sure to use deepcopy only when inplace is set to `True`.
-            quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
-            best_quant_model = model  # quantize model inplace
+            q_model = quantize(
+                deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True
+            )
+            best_quant_model = q_model  # quantize model inplace
             break
     tuning_logger.tuning_end()
     return best_quant_model
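The bug this hunk fixes: quantize(deepcopy(model), ..., inplace=True) mutates only the deep copy, so the old `best_quant_model = model` handed back the original, still-unquantized module. The new code keeps the object quantize() returns. A minimal self-contained sketch of that failure mode (the Toy class, fake_quantize, and the quantized flag are illustrative stand-ins, not the library's internals):

from copy import deepcopy


class Toy:
    """Stand-in for a torch module; `quantized` mimics an in-place mutation."""

    quantized = False


def fake_quantize(model, inplace=True):
    # Stand-in for quantize(): mutates `model` in place and returns it.
    model.quantized = True
    return model


model = Toy()

# Buggy pattern: the mutated copy is discarded and the original is kept.
fake_quantize(deepcopy(model), inplace=True)
assert model.quantized is False  # the original was never quantized

# Fixed pattern: keep the object that quantize() returns.
q_model = fake_quantize(deepcopy(model), inplace=True)
assert q_model.quantized is True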

neural_compressor/torch/quantization/config.py (+11 −2)

@@ -941,14 +941,23 @@ def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]
 
 
 def get_default_fp8_config() -> FP8Config:
-    """Generate the default gptq config.
+    """Generate the default fp8 config.
 
     Returns:
-        the default gptq config.
+        the default fp8 config.
     """
     return FP8Config()
 
 
+def get_default_fp8_config_set() -> FP8Config:
+    """Generate the default fp8 config set.
+
+    Returns:
+        the default fp8 config.
+    """
+    return FP8Config.get_config_set_for_tuning()
+
+
 ##################### Algo Configs End ###################################
 
 
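For context, here is how the new helper is meant to be consumed, mirroring the unit test below (a sketch; model, calib_func, and eval_func are placeholders you would supply, and FP8 quantization assumes a Habana Gaudi device):

from neural_compressor.torch.quantization import (
    TuningConfig,
    autotune,
    get_default_fp8_config_set,
)

# Expand FP8Config's tunable options into a config set for the tuner.
config_set = get_default_fp8_config_set()
tune_config = TuningConfig(config_set=config_set, tolerable_loss=0.01)

# best_model = autotune(
#     model=model, tune_config=tune_config, run_fn=calib_func, eval_fns=eval_func
# )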

test/3x/torch/quantization/habana_fp8/test_fp8.py (+36 −2)

@@ -18,8 +18,14 @@
     FP8Matmul,
     Matmul,
 )
-from neural_compressor.torch.quantization import quantize
-from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
+from neural_compressor.torch.quantization import (
+    FP8Config,
+    TuningConfig,
+    autotune,
+    get_default_fp8_config,
+    get_default_fp8_config_set,
+    quantize,
+)
 
 torch.set_grad_enabled(False)
 
@@ -164,3 +170,31 @@ def calib_func(model):
         assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check."
         assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check."
         assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check."
+
+    def test_autotune(self):
+        m = copy.deepcopy(self.model)
+        inp = self.inp
+        fp32_out = m(inp)
+
+        def calib_func(model):
+            model(inp)
+
+        accu_list = [1.0, 0.9, 0.99]
+
+        def eval_func(model):
+            nonlocal accu_list
+            return accu_list.pop()
+
+        tune_config = TuningConfig(
+            config_set=get_default_fp8_config_set(),
+            tolerable_loss=0.01,
+        )
+        best_model = autotune(
+            model=m,
+            tune_config=tune_config,
+            run_fn=calib_func,
+            eval_fns=eval_func,
+        )
+        assert isinstance(best_model.fc1, FP8Linear), "Unexpected result. Please double check."
+        assert isinstance(best_model.mm, FP8Matmul), "Unexpected result. Please double check."
+        assert isinstance(best_model.bmm, FP8BatchMatmul), "Unexpected result. Please double check."
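One non-obvious detail in the new test: eval_func pops from the end of accu_list, so successive evaluations observe 0.99, then 0.9, then 1.0, simulating per-trial accuracy measurements for the tuner (a reading of the test fixture, not documented library behavior):

accu_list = [1.0, 0.9, 0.99]
print(accu_list.pop())  # 0.99 -> first evaluation
print(accu_list.pop())  # 0.9  -> second evaluation
print(accu_list.pop())  # 1.0  -> third evaluation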
