Commit 1f58f02

Add set_local support for static quant with pt2e (#1870)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 0341295 commit 1f58f02
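
In user terms, this change lets a model-level static quant config carry per-operator overrides via `set_local`, and the pt2e path now honors them. A minimal usage sketch, adapted from the test added in this commit (the import path is assumed to be `neural_compressor.torch.quantization`; `exported_model` and `calib_fn` are placeholders for an exported model and a calibration function):

    from neural_compressor.torch.quantization import (
        StaticQuantConfig,
        get_default_static_config,
        quantize,
    )

    quant_config = get_default_static_config()
    # Override the global int8 config for one module: keep `fc1` in fp32.
    quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
    q_model = quantize(model=exported_model, quant_config=quant_config, run_fn=calib_fn)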

File tree

4 files changed (+74, -14 lines)


neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py (+8, -7)

@@ -106,11 +106,11 @@ def get_filter_fn(node_list, fn):
    def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
        """Filter the node with target operator in match and check if it is in `node_list`."""
        target_node = None
-       for node in pattern_graph.nodes:
+       for node in pattern_graph.nodes:  # pragma: no cover
            if node.target == target_op:
                target_node = node
                break
-       if target_node is None:
+       if target_node is None:  # pragma: no cover
            return False
        matched_node = match.nodes_map[target_node]
        return matched_node in node_list
@@ -137,7 +137,8 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if meta := getattr(node, "meta"):
            if quantization_annotation := meta.get(xiq.QUANT_ANNOTATION_KEY):
-               if quantization_annotation._annotated:
+               none_annotation = xiq._X86InductorQuantizationAnnotation(_annotated=True)
+               if quantization_annotation != none_annotation:  # pragma: no cover
                    continue
        unquantized_node_set.add(node)
    return unquantized_node_set
@@ -161,18 +162,18 @@ def _parse_node_candidate_set_from_user_config(config, gm):
    op_type_configs, op_name_configs = config._get_op_name_op_type_config()
    op_type_filters = []
    op_name_filters = []
-   for op_type_name, config in op_type_configs.items():
+   for op_type_name, config in op_type_configs.items():  # pragma: no cover
        op_type = getattr(torch.nn, op_type_name)
-       if config.act_dtype == "fp16":
+       if config.act_dtype == "fp16":  # pragma: no cover
            filter = xpq._get_module_type_filter(op_type)
            op_type_filters.append(filter)
    for op_name, config in op_name_configs.items():
-       if config.act_dtype == "fp16":
+       if config.act_dtype == "fp16":  # pragma: no cover
            filter = xpq._get_module_name_filter(op_name)
            op_name_filters.append(filter)
    node_set_from_user_config = set()
    all_filters = op_type_filters + op_name_filters
-   for node in gm.graph.nodes:
+   for node in gm.graph.nodes:  # pragma: no cover
        if any([filter(node) for filter in all_filters]):
            node_set_from_user_config.add(node)
    return node_set_from_user_config

neural_compressor/torch/algorithms/pt2e_quant/utility.py (+23, -1)

@@ -20,6 +20,8 @@
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer

+from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2
+

 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
     dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
@@ -53,6 +55,9 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False):


 def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
+    NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
+    if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES:  # pragma: no cover
+        return None
     default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
         inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
@@ -75,5 +80,22 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
     # set global
     global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
-    # Skip the local config for now (need torch 2.4)
+    # set local (requires torch > 2.3.2)
+    if GT_TORCH_VERSION_2_3_2:  # pragma: no cover
+        op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
+        if op_type_config_dict:
+            for op_type, config in op_type_config_dict.items():
+                _nn_module_type = getattr(torch.nn, op_type, None)
+                if _nn_module_type:
+                    quantizer.set_module_type_qconfig(
+                        _nn_module_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
+                    )
+                _nn_func_type = getattr(torch.nn.functional, op_type, None)
+                if _nn_func_type:
+                    quantizer.set_function_type_qconfig(
+                        _nn_func_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
+                    )
+        if op_name_config_dict:
+            for op_name, config in op_name_config_dict.items():
+                quantizer.set_module_name_qconfig(op_name, _map_inc_config_to_torch_quant_config(config, is_dynamic))
     return quantizer
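
The new branch translates each `set_local` entry into the matching X86InductorQuantizer call, and `_map_inc_config_to_torch_quant_config` now returns `None` when both dtypes are non-quant, which tells the quantizer to leave that op in fp32. Roughly, the `fc1` example above boils down to the following quantizer calls (a sketch, assuming torch > 2.3.2 where these `set_*_qconfig` methods exist):

    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    # set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
    # maps to a None qconfig for that module name, i.e. fc1 stays unquantized.
    quantizer.set_module_name_qconfig("fc1", None)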

neural_compressor/torch/utils/environ.py (+3)

@@ -91,6 +91,9 @@ def get_torch_version():
     return version


+GT_TORCH_VERSION_2_3_2 = get_torch_version() > Version("2.3.2")
+
+
 def get_accelerator(device_name="auto"):
     global accelerator  # update the global accelerator when calling this func
     from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
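
Note the gate is a strict greater-than: torch 2.3.2 itself is excluded, and `packaging` version ordering means 2.4 pre-release builds already pass it. For example:

    from packaging.version import Version

    assert Version("2.4.0") > Version("2.3.2")
    assert Version("2.4.0.dev20240501") > Version("2.3.2")  # 2.4 dev builds pass the gate
    assert not (Version("2.3.2") > Version("2.3.2"))        # 2.3.2 itself does not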

test/3x/torch/quantization/test_pt2e_quant.py (+40, -6)

@@ -17,7 +17,7 @@
     prepare,
     quantize,
 )
-from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
+from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2, TORCH_VERSION_2_2_2, get_torch_version

 torch.manual_seed(0)

@@ -119,6 +119,42 @@ def calib_fn(model):
         logger.warning("out shape is %s", out.shape)
         assert out is not None

+    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
+    def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
+        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
+        float_model_output = model(*example_inputs)
+        quant_config = None
+
+        def calib_fn(model):
+            for i in range(4):
+                model(*example_inputs)
+
+        quant_config = get_default_static_config()
+        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
+
+        # check the node occurrence
+        expected_node_occurrence = {
+            # Only `fc2` is quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+        }
+        expected_node_occurrence = {
+            torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        node_in_graph = self.get_node_in_graph(q_model)
+        for node, cnt in expected_node_occurrence.items():
+            assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but occurs {node_in_graph.get(node, 0)}"
+
+        from torch._inductor import config
+
+        config.freezing = True
+        q_model_out = q_model(*example_inputs)
+        assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"
+        opt_model = torch.compile(q_model)
+        out = opt_model(*example_inputs)
+        assert out is not None
+
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
     @pytest.mark.parametrize("is_dynamic", [False, True])
     def test_prepare_and_convert_on_simple_model(self, is_dynamic, force_not_import_ipex):
@@ -193,9 +229,9 @@ def get_node_in_graph(graph_module):
                 nodes_in_graph[n] += 1
             else:
                 nodes_in_graph[n] = 1
-        return
+        return nodes_in_graph

-    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
+    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
     def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         model, example_inputs = self.build_model_include_conv_and_linear()
         model = export(model, example_inputs=example_inputs)
@@ -221,9 +257,7 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         }
         node_in_graph = self.get_node_in_graph(converted_model)
         for node, cnt in expected_node_occurrence.items():
-            assert (
-                expected_node_occurrence.get(node, 0) == cnt
-            ), f"Node {node} should occur {cnt} times, but {node_in_graph[node]}"
+            assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but occurs {node_in_graph.get(node, 0)}"

         # inference
         from torch._inductor import config
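
(To run just the new coverage locally, something like `pytest test/3x/torch/quantization/test_pt2e_quant.py -k set_local` should select the added test, provided the installed torch is newer than 2.3.2.)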
