
Commit 70a1d50

fix 3x ipex static quant regression (#1864)

Description: fixes two regressions in the 3.x IPEX static quantization path: fallback by op type name (e.g. "linear") did not work, and the dumped op stats were wrong (the "Linear&relu" op type was missing).

Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>

1 parent 4e45f8f
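With the fix in place, fallback by op type name works again in the 3.x flow. A minimal usage sketch, assuming the 3.x API names exercised by the updated test below (the import path, toy model, and example input are illustrative, not part of this commit):

import torch
from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    convert,
    get_default_static_config,
    prepare,
)

model = torch.nn.Sequential(torch.nn.Linear(30, 50), torch.nn.ReLU())  # toy model
example_inputs = torch.randn(1, 30)

quant_config = get_default_static_config()
# fall back every Linear op to fp32 by op type name (the regression this commit fixes)
quant_config.set_local("linear", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

prepared_model = prepare(model, quant_config=quant_config, example_inputs=example_inputs)
prepared_model(example_inputs)  # calibration pass
q_model = convert(prepared_model)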

File tree: 3 files changed (+81 -40 lines)

neural_compressor/torch/algorithms/smooth_quant/utility.py (+54 -1)
@@ -26,8 +26,8 @@

 from neural_compressor.torch.algorithms.static_quant import (
     CpuInfo,
+    Statistics,
     TransformerBasedModelBlockPatternDetector,
-    dump_model_op_stats,
     generate_activation_observer,
     get_quantizable_ops_from_cfgs,
     ipex_config_path,
@@ -251,6 +251,59 @@ def cfg_to_qconfig(
     return None


+def dump_model_op_stats(user_cfg):
+    """This is a function to dump quantizable ops of model to user.
+
+    Args:
+        user_cfg (dict): quantization config
+    Returns:
+        None
+    """
+    res = dict()
+    for k, v in user_cfg.items():
+        op_type_list = k[-1].split("><")
+        op_type = ""
+        for op in op_type_list:
+            if "class" in op:
+                op_type = (
+                    op[op.rfind(".") + 1 : op.rfind("'")]
+                    if op_type == ""
+                    else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
+                )
+            elif "method" in op:
+                start = op.find("'") + 1
+                if start > 1:
+                    op_type = (
+                        op[start : op.find("'", start)]
+                        if op_type == ""
+                        else op_type + "&" + op[start : op.find("'", start)]
+                    )
+                else:
+                    start = op.find("method") + 7
+                    op_type = (
+                        op[start : op.find(" ", start)]
+                        if op_type == ""
+                        else op_type + "&" + op[start : op.find(" ", start)]
+                    )
+            else:
+                op_type = op if op_type == "" else op_type + "&" + op
+        if op_type not in res.keys():
+            res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
+        if v["weight"]["dtype"] == "int8":
+            res[op_type]["INT8"] += 1
+        elif v["weight"]["dtype"] == "fp32":
+            res[op_type]["FP32"] += 1
+
+    output_data = [
+        [op_type, sum(res[op_type].values()), res[op_type]["INT8"], res[op_type]["BF16"], res[op_type]["FP32"]]
+        for op_type in res.keys()
+    ]
+
+    Statistics(
+        output_data, header="Mixed Precision Statistics", field_names=["Op Type", "Total", "INT8", "BF16", "FP32"]
+    ).print_stat()
+
+
 def get_parent(node, all_parents=False):  # pragma: no cover
     if node.inputs() is None:
         return None
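The smooth_quant flow keeps its own copy of dump_model_op_stats because its user_cfg keys still carry raw IPEX op-type strings that must be parsed into readable names. A standalone trace of the two main branches above, on hypothetical key fragments:

# "class" branch: take the text between the last "." and the closing quote
op = "<class 'torch.nn.modules.linear.Linear'>"
print(op[op.rfind(".") + 1 : op.rfind("'")])  # Linear

# "method" branch: take the text between the first pair of quotes
op = "<method 'add' of 'torch._C.TensorBase' objects>"
start = op.find("'") + 1
print(op[start : op.find("'", start)])  # add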

neural_compressor/torch/algorithms/static_quant/utility.py (+15 -33)
@@ -43,13 +43,19 @@
     "<class 'torch.nn.modules.conv.Conv2d'>": "Conv2d",
     "<class 'torch.nn.modules.conv.Conv3d'>": "Conv3d",
     "<class 'torch.nn.modules.activation.ReLU'>": "ReLU",
+    "<class 'torch.nn.modules.sparse.EmbeddingBag'>": "EmbeddingBag",
     "<method 'add' of 'torch._C._TensorBase' objects>": "add",  # for IPEX < 2.2
     "<method 'add' of 'torch._C.TensorBase' objects>": "add",  # for IPEX >= 2.2
     "<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>": "AdaptiveAvgPool2d",
     "Linear_Relu": "Linear",
+    "Linear_add": "Linear",
     "<class 'torch.nn.modules.linear.Linear'>": "Linear",
     "<class 'torch.nn.modules.pooling.MaxPool2d'>": "MaxPool2d",
-    "re": {"<built-in method matmul of type object at": "matmul"},
+    "re": {
+        "<built-in method matmul of type object at": "matmul",
+        "<built-in method add of type object at": "add",
+        "<built-in method bmm of type object at": "bmm",
+    },
 }

 BLOCK_PATTERNS = [
@@ -85,6 +91,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
     """
+    ori_user_cfg = copy.deepcopy(user_cfg)
     tmp_user_cfg = OrderedDict()
     for op in user_cfg:  # map ipex op_name to pt op_name
         for i, op_name in enumerate(op):
@@ -94,9 +101,9 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
                 ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
                 tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
                 break
-    user_cfg = tmp_user_cfg
-    for op_name in user_cfg:
-        inc_op_cfg = user_cfg[op_name]
+
+    for op_name in tmp_user_cfg:
+        inc_op_cfg = tmp_user_cfg[op_name]
         for i, name in enumerate(op_name[0]):
             # to int8
             ipex_op_cfg = op_infos_from_cfgs[name]
@@ -154,7 +161,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
             else:
                 pass
             cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
-    return cfgs, user_cfg
+    return cfgs, ori_user_cfg


 def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False):  # pragma: no cover
@@ -333,8 +340,8 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
        elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
            method = ipex_op_type.split("'")[1]
            op_name_info.append((module_fqn, method))
-        elif "Convolution" in ipex_op_type:  # "Convolution_Relu"
-            op_name_info.append((module_fqn, "Conv2d"))
+        elif "_" in ipex_op_type:  # "Convolution_Relu", "Linear_Relu"
+            op_name_info.append((module_fqn, ipex_op_type.split("_")[0]))
        else:
            re_flag = False
            for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
@@ -394,32 +401,7 @@ def dump_model_op_stats(user_cfg):
     """
     res = dict()
     for k, v in user_cfg.items():
-        op_type_list = k[-1].split("><")
-        op_type = ""
-        for op in op_type_list:
-            if "class" in op:
-                op_type = (
-                    op[op.rfind(".") + 1 : op.rfind("'")]
-                    if op_type == ""
-                    else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
-                )
-            elif "method" in op:
-                start = op.find("'") + 1
-                if start > 1:
-                    op_type = (
-                        op[start : op.find("'", start)]
-                        if op_type == ""
-                        else op_type + "&" + op[start : op.find("'", start)]
-                    )
-                else:
-                    start = op.find("method") + 7
-                    op_type = (
-                        op[start : op.find(" ", start)]
-                        if op_type == ""
-                        else op_type + "&" + op[start : op.find(" ", start)]
-                    )
-            else:
-                op_type = op if op_type == "" else op_type + "&" + op
+        op_type = k[1]
         if op_type not in res.keys():
             res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
         if v["weight"]["dtype"] == "int8":

test/3x/torch/quantization/test_static_quant.py (+12 -6)
@@ -22,13 +22,18 @@ class Model(torch.nn.Module):
     def __init__(self):
         super(Model, self).__init__()
         self.fc1 = torch.nn.Linear(30, 50)
-        self.fc2 = torch.nn.Linear(50, 30)
-        self.fc3 = torch.nn.Linear(30, 5)
+        self.fc2 = torch.nn.Linear(50, 50)
+        self.fc3 = torch.nn.Linear(50, 30)
+        self.fc4 = torch.nn.Linear(30, 5)
+        self.relu = torch.nn.ReLU()

     def forward(self, x):
         out = self.fc1(x)
         out = self.fc2(out)
+        out = self.relu(out)
         out = self.fc3(out)
+        out = out + x
+        out = self.fc4(out)
         return out

 model = Model()
@@ -78,21 +83,22 @@ def test_static_quant_fallback(self):
         assert q_model is not None, "Quantization failed!"

         for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
-            if op_info["op_type"] == "<class 'torch.nn.modules.linear.Linear'>":
+            if op_info["op_type"] == "Linear":
                 dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
                 assert dtype == "torch.float32", "Failed to fallback linear op, please check!"

         # fallback by op_name
-        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config = get_default_static_config()
+        quant_config.set_local("fc2", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
         assert q_model is not None, "Quantization failed!"

         for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
-            if op_info["fqn"] == "fc1":
+            if op_info["fqn"] == "fc2":
                 dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
-                assert dtype == "torch.float32", "Failed to fallback fc1 layer, please check!"
+                assert dtype == "torch.float32", "Failed to fallback fc2 layer, please check!"

     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
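The test model now routes fc2 through a ReLU and adds a residual out + x, so the quantized graph exercises the fused Linear/ReLU and add paths behind the missing "Linear&relu" stats row this commit fixes. To run just this test, a standard pytest invocation from the repo root should suffice (shown via pytest.main to stay in Python):

import pytest

pytest.main(["test/3x/torch/quantization/test_static_quant.py", "-k", "test_static_quant_fallback"])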
