Commit a617115

Authored by yiliu30 and xin3he

Add autotune support for PT2E (#2110)

Add autotune support for PT2E and disable some conv1d-related tests on HPU.

Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Xin He <xin3.he@intel.com>
1 parent d2e49d2 commit a617115

File tree: 7 files changed (+138, -34 lines)
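For orientation before the per-file diffs, here is a minimal sketch of the end-to-end flow this commit enables, modeled on the new test added below. The `autotune`, `TuningConfig`, `INT8StaticQuantConfig`, and `StaticQuantConfig` names come from the diff; the `export` import path, the toy model, and the eval function are illustrative assumptions rather than part of the commit.

```python
# Sketch only: mirrors the flow of test_auto_tune_mixed_int8_and_16bits from this commit.
# The toy model, the placeholder metric, and the export import path are assumptions.
import torch

from neural_compressor.torch.export import export  # assumed path for the PT2E export helper
from neural_compressor.torch.quantization.autotune import TuningConfig, autotune
from neural_compressor.torch.quantization.config import INT8StaticQuantConfig, StaticQuantConfig

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
example_inputs = (torch.randn(2, 16),)
exported_model = export(model, example_inputs=example_inputs)

# Candidate 1: INT8 everywhere. Candidate 2: keep the first Linear ("0") in half precision.
config1 = INT8StaticQuantConfig()
config2 = INT8StaticQuantConfig().set_local("0", StaticQuantConfig(w_dtype="fp16", act_dtype="fp16"))
tune_config = TuningConfig(config_set=[config1, config2])


def eval_fn(q_model):
    return 1.0  # placeholder metric; a real eval_fn returns task accuracy


def run_fn(q_model):
    q_model(*example_inputs)  # calibration pass


q_model = autotune(
    model=exported_model,
    tune_config=tune_config,
    eval_fn=eval_fn,
    run_fn=run_fn,
    example_inputs=example_inputs,
)
```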

neural_compressor/common/base_config.py

+7 -1

@@ -18,6 +18,7 @@
 
 from __future__ import annotations
 
+import copy
 import inspect
 import json
 import os
@@ -539,6 +540,7 @@ def expand(self) -> List[BaseConfig]:
             tuning_param_pair = dict(zip(tuning_param_name_lst, params_values))
             tmp_params_dict = {**not_tuning_param_pair, **tuning_param_pair}
             new_config = self.__class__(**tmp_params_dict)
+            new_config.local_config = copy.deepcopy(self.local_config)
             logger.info(new_config.to_dict())
             config_list.append(new_config)
         logger.info("Expanded the %s and got %d configs.", self.__class__.name, len(config_list))
@@ -629,9 +631,13 @@ def __eq__(self, other: BaseConfig) -> bool:
         """
         if not isinstance(other, type(self)):
             return False
-        return self.params_list == other.params_list and all(
+
+        params_equal = self.params_list == other.params_list and all(
             getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list
         )
+        local_config_equal = self.local_config == other.local_config
+        global_config_equal = self.global_config == other.global_config
+        return params_equal and local_config_equal and global_config_equal
 
 
 class ComposableConfig(BaseConfig):
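The practical effect of the `__eq__` change: two configs with identical tuning parameters but different per-op overrides (set via `set_local`) no longer compare equal, so the expanded candidates built above stay distinguishable during tuning. A small sketch, using the static-quant configs that appear in this commit's test (the op name "fc1" is illustrative):

```python
# Sketch: with this commit, local (per-op) overrides participate in config equality.
from neural_compressor.torch.quantization.config import INT8StaticQuantConfig, StaticQuantConfig

a = INT8StaticQuantConfig()
b = INT8StaticQuantConfig()
assert a == b  # same params, same (empty) local config

# Override one op on `b`; "fc1" is an illustrative op name.
b.set_local("fc1", StaticQuantConfig(w_dtype="fp16", act_dtype="fp16"))
assert a != b  # previously __eq__ ignored local_config, so these could still look equal
```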

neural_compressor/torch/algorithms/pt2e_quant/core.py

+3 -1

@@ -18,6 +18,7 @@
 
 from typing import Any
 
+import torch
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
@@ -102,4 +103,5 @@ def half_precision_transformation(self, model, config):
         """
         half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
         logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
-        hp_rewriter.transformation(model, half_precision_node_set)
+        hp_rewriter.transformation(model, half_precision_node_set, torch.float16)
+        hp_rewriter.transformation(model, half_precision_node_set, torch.bfloat16)

neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py

+46 -20

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Rewrite the FP32 operators to FP16 or BF16 operators."""
 
+from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Tuple
@@ -25,7 +26,7 @@
 from torch.fx.subgraph_rewriter import Match
 from typing_extensions import TypeAlias
 
-from neural_compressor.common import utils
+from neural_compressor.common import logger, utils
 
 # =============================================================================
 # Search and replace patterns
@@ -50,25 +51,44 @@ class PatternPair:
 
 # key: torch func
 # value: the tuple of args
-FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
+FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]
 
 
-# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
-# TODO: complete the mapping
+# Align with xiq, as it relies on xiq's set_module_xx capability
 FN_ARGS_MAPPING: FuncArgsMappingType = {
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    # Note: ORDER matters
+    torch.nn.functional.linear: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
+        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    ],
+    torch.nn.functional.conv2d: [
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
+    ],
+    torch.matmul: [
+        (torch.randn(0, 0), torch.randn(0, 0)),
+        (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),
+        (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),
+    ],
 }
-# TODO: complete the mapping
-FN_ATEN_OPS_MAPPING = {
-    torch.nn.functional.linear: torch.ops.aten.linear.default,
+
+# module cls <-> function name
+NN_MODULES_TO_NN_FN = {
+    torch.nn.Linear: torch.nn.functional.linear,
+    torch.nn.Conv2d: torch.nn.functional.conv2d,
 }
 
+# Use the mapping from xiq
+FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()
+
 SUPPORTED_OPERATORS = FN_ATEN_OPS_MAPPING.values()
 
 
 PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
-HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}
+HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {
+    torch.float16: defaultdict(list),
+    torch.bfloat16: defaultdict(list),
+}
 
 # FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
 # BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]
@@ -98,15 +118,18 @@ def replace_fn_wrapper(fn_args, fn):
 
 
 def _register_pattern_pair(dtype: torch.dtype) -> None:
-    for fn, fn_args in FN_ARGS_MAPPING.items():
-        pattern_pair = pattern_factory(fn, fn_args)
-        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
-    utils.logger.info(
+    for fn, fn_args_lst in FN_ARGS_MAPPING.items():
+        for fn_args in fn_args_lst:
+            logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
+            pattern_pair = pattern_factory(fn, fn_args)
+            HALF_PRECISION_PATTERN_REGISTRY[dtype][fn].append(pattern_pair)
+    utils.logger.debug(
         f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
     )
 
 
 _register_pattern_pair(torch.float16)
+_register_pattern_pair(torch.bfloat16)
 
 
 def get_filter_fn(node_list, fn):
@@ -182,9 +205,10 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
 
 def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
     """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
-    for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
-        apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
-    utils.logger.info("Half precision conversion is done:")
+    for pattern_pair_lst in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
+        for pattern_pair in pattern_pair_lst:
+            apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
+    utils.logger.info(f"Half precision conversion({target_dtype}) completed.")
     if utils.level_name == "DEBUG":  # pragma: no cover
         gm.print_readable(True)
 
@@ -201,11 +225,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
     op_name_filters = []
     for op_type_name, config in op_type_configs.items():  # pragma: no cover
         op_type = getattr(torch.nn, op_type_name)
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
            filter = xpq._get_module_type_filter(op_type)
            op_type_filters.append(filter)
     for op_name, config in op_name_configs.items():
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
            filter = xpq._get_module_name_filter(op_name)
            op_name_filters.append(filter)
     node_set_from_user_config = set()
@@ -237,5 +261,7 @@ def get_half_precision_node_set(gm, config):
     for node in possible_node_set:
         if node.target in SUPPORTED_OPERATORS:
             half_precision_node_set.add(node)
-    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
+    utils.logger.info(
+        f"Found {len(half_precision_node_set)} nodes to convert to half precision: {half_precision_node_set}"
+    )
     return half_precision_node_set
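A side note on the registry change above: the old `FN_ARGS_MAPPING` dict literal used `torch.nn.functional.linear` as a key twice, so the bias-free pattern was silently overwritten; switching the registry values to lists under `defaultdict(list)` is what lets several argument signatures coexist per function. A self-contained illustration of the pitfall and the fix:

```python
# Why lists (and defaultdict(list)) in the pattern registry: duplicate keys in a
# dict literal keep only the last value, so one of the two linear variants was lost.
from collections import defaultdict

import torch

old_style = {
    torch.nn.functional.linear: "pattern for linear w/o bias",
    torch.nn.functional.linear: "pattern for linear w/ bias",  # silently overwrites the entry above
}
assert len(old_style) == 1

new_style = defaultdict(list)
new_style[torch.nn.functional.linear].append("pattern for linear w/o bias")
new_style[torch.nn.functional.linear].append("pattern for linear w/ bias")
assert len(new_style[torch.nn.functional.linear]) == 2  # both variants are kept
```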

neural_compressor/torch/algorithms/pt2e_quant/utility.py

+2 -1

@@ -26,7 +26,7 @@
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 
-from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5
+from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5, logger
 
 
 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
@@ -79,6 +79,7 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
 def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
     NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
     if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES:  # pragma: no cover
+        logger.debug("Got non-quantizable data types, skipping quantization.")
         return None
     default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
neural_compressor/torch/quantization/autotune.py

+12 -3

@@ -54,6 +54,15 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
 
 
+def _deepcopy_warp(model):
+    additional_attr_lst = ["_exported", "dynamic_shapes"]
+    original_attr = {key: getattr(model, key, None) for key in additional_attr_lst}
+    new_model = deepcopy(model)
+    for key, value in original_attr.items():
+        setattr(new_model, key, value)
+    return new_model
+
+
 @dump_elapsed_time("Pass auto-tune")
 def autotune(
     model: torch.nn.Module,
@@ -81,7 +90,7 @@
     best_quant_model = None
     eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
     config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
-    baseline: float = eval_func_wrapper.evaluate(deepcopy(model))
+    baseline: float = eval_func_wrapper.evaluate(_deepcopy_warp(model))
     tuning_monitor.set_baseline(baseline)
     tuning_logger.tuning_start()
     for trial_index, quant_config in enumerate(config_loader, 1):
@@ -90,7 +99,7 @@
         logger.info(quant_config.to_dict())
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(
-            deepcopy(model),
+            _deepcopy_warp(model),
             quant_config=quant_config,
             run_fn=run_fn,
             run_args=run_args,
@@ -112,7 +121,7 @@
         best_quant_config: BaseConfig = best_trial_record.quant_config
         # !!! Make sure to use deepcopy only when inplace is set to `True`.
         q_model = quantize(
-            deepcopy(model),
+            _deepcopy_warp(model),
             quant_config=best_quant_config,
             run_fn=run_fn,
             run_args=run_args,
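Regarding `_deepcopy_warp`: the PT2E export step attaches bookkeeping attributes to the model (the helper lists `_exported` and `dynamic_shapes`), and a plain `deepcopy` is not guaranteed to carry them over, so each trial could otherwise start from a model that no longer looks exported. A short check of the guarantee the helper provides; calling the private helper directly, and the `export` import path, are assumptions made only for illustration:

```python
# Illustration only: _deepcopy_warp is a private helper that autotune calls internally.
import torch

from neural_compressor.torch.export import export  # assumed import path
from neural_compressor.torch.quantization.autotune import _deepcopy_warp

model = torch.nn.Linear(8, 8)
example_inputs = (torch.randn(2, 8),)
exported_model = export(model, example_inputs=example_inputs)

trial_model = _deepcopy_warp(exported_model)
assert trial_model is not exported_model  # an independent copy per tuning trial
# The PT2E bookkeeping attributes are re-attached, so the copy is still treated as exported.
assert getattr(trial_model, "_exported", None) == getattr(exported_model, "_exported", None)
assert getattr(trial_model, "dynamic_shapes", None) == getattr(exported_model, "dynamic_shapes", None)
```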

test/3x/torch/quantization/test_pt2e_quant.py

+66 -8

@@ -29,7 +29,6 @@ def _is_ipex_imported():
     monkeypatch.setattr("neural_compressor.torch.quantization.algorithm_entry.is_ipex_imported", _is_ipex_imported)
     monkeypatch.setattr("neural_compressor.torch.export.pt2e_export.is_ipex_imported", _is_ipex_imported)
 
-
 class TestPT2EQuantization:
     def teardown_class(self):
         shutil.rmtree("saved_results", ignore_errors=True)
@@ -53,15 +52,15 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         return bar, example_inputs
 
     @staticmethod
-    def build_model_include_conv_and_linear():
+    def build_model_include_conv_and_linear(bias=True):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, bias=True):
                 super(Model, self).__init__()
-                self.conv1 = torch.nn.Conv2d(3, 6, 5)
+                self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=bias)
                 self.pool = torch.nn.MaxPool2d(2, 2)
-                self.conv2 = torch.nn.Conv2d(6, 16, 5)
-                self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
-                self.fc2 = torch.nn.Linear(120, 84)
+                self.conv2 = torch.nn.Conv2d(6, 16, 5, bias=bias)
+                self.fc1 = torch.nn.Linear(16 * 5 * 5, 120, bias=bias)
+                self.fc2 = torch.nn.Linear(120, 84, bias=bias)
 
             def forward(self, x):
                 x = self.conv1(x)
@@ -74,7 +73,7 @@ def forward(self, x):
 
                 return x
 
-        model = Model()
+        model = Model(bias)
         example_inputs = (torch.randn(1, 3, 32, 32),)
         return model, example_inputs
 
@@ -283,3 +282,62 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         opt_model = torch.compile(converted_model)
         out = opt_model(*example_inputs)
         assert out is not None
+
+    @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
+    @pytest.mark.parametrize("half_precision_dtype", ["fp16", "bf16"])
+    @pytest.mark.parametrize("op_name_or_type", ["conv1", "fc1", torch.nn.Linear, torch.nn.Conv2d])
+    @pytest.mark.parametrize("bias", [True, False])
+    def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name_or_type, bias, force_not_import_ipex):
+        # Just make sure the pattern matches, not the accuracy.
+        # config1: int8 for all
+        # config2: half precision for linear/conv
+        from neural_compressor.torch.quantization.config import INT8StaticQuantConfig
+        from neural_compressor.torch.quantization.autotune import autotune, TuningConfig
+
+        config1 = INT8StaticQuantConfig()
+        config2 = INT8StaticQuantConfig().set_local(
+            op_name_or_type, StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype)
+        )
+        tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=-0.1)
+        eval_result = [1, 1, 2]
+
+        def fake_eval_fn(model):
+            res = eval_result.pop(0)
+            return res
+
+        def run_fn(model):
+            for i in range(2):
+                model(*example_inputs)
+
+        model, example_inputs = self.build_model_include_conv_and_linear(bias)
+        model = export(model, example_inputs=example_inputs)
+        qmodel = autotune(
+            model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs
+        )
+
+        # Calculate the expected number of `aten.to` operations based on bias and op_name_or_type
+        """
+        | Bias  | op_name | nn.Module |
+        |-------|---------|-----------|
+        | True  |    4    |     8     |
+        | False |    3    |     6     |
+        """
+        expected_node_occurrence = {
+            torch.ops.aten.to.dtype: (3 + int(bias)) * (1 if isinstance(op_name_or_type, str) else 2)
+        }
+
+        expected_node_occurrence = {
+            torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        node_in_graph = self.get_node_in_graph(qmodel)
+        for node, cnt in expected_node_occurrence.items():
+            assert (
+                node_in_graph.get(node, 0) == cnt
+            ), f"Node {node} should occur {cnt} times, but got {node_in_graph.get(node, 0)}."
+        # inference
+        from torch._inductor import config
+
+        config.freezing = True
+        opt_model = torch.compile(qmodel)
+        out = opt_model(*example_inputs)
+        assert out is not None
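For readers checking the expected-count logic in the new test: the formula `(3 + int(bias)) * (1 if isinstance(op_name_or_type, str) else 2)` reproduces the table in the docstring, since a type target (`nn.Linear` or `nn.Conv2d`) matches two layers of the test model while an op name matches only one. A quick arithmetic check:

```python
# Reproduces the docstring table from the formula used for expected `aten.to` counts.
def expected_aten_to(bias: bool, targets_module_type: bool) -> int:
    return (3 + int(bias)) * (2 if targets_module_type else 1)


assert expected_aten_to(bias=True, targets_module_type=False) == 4   # op name, with bias
assert expected_aten_to(bias=True, targets_module_type=True) == 8    # nn.Module type, with bias
assert expected_aten_to(bias=False, targets_module_type=False) == 3  # op name, no bias
assert expected_aten_to(bias=False, targets_module_type=True) == 6   # nn.Module type, no bias
```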

test/3x/torch/quantization/weight_only/test_rtn.py

+2

@@ -309,6 +309,8 @@ def test_rtn_with_quantize_API(self):
         ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."
 
     # TODO: (4, True, 32, 0), group_dim=0, format not supported
+    # TODO [SW-216127]: not high priority, so we can implement it later.
+    @pytest.mark.skipif(is_hpex_available(), reason="These tests are not supported on HPU for now.")
     @pytest.mark.parametrize(
         "bits, use_sym, group_size, group_dim",
         [
