Commit 855c10c

map ipex op_name w/ pt op_name (#1740)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
1 parent e87c95f commit 855c10c

6 files changed (+58, -18 lines)

neural_compressor/common/base_config.py (+2, -4)

```diff
@@ -410,11 +410,9 @@ def to_config_mapping(
                 if self.global_config is not None:
                     config_mapping[(op_name, op_type)] = global_config
                 if op_type in op_type_config_dict:
-                    config_mapping[(op_name, op_type)] = op_name_config_dict[op_type]
+                    config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
                 for op_name_pattern in op_name_config_dict:
-                    if isinstance(op_name, str) and re.match(op_name_pattern, op_name):
-                        config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
-                    elif op_name_pattern == op_name:  # TODO: map ipex opname to stock pt op_name
+                    if re.match(op_name_pattern, op_name):
                         config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
         return config_mapping
```
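The dropped `elif op_name_pattern == op_name` branch (and its TODO) is the other half of this commit: IPEX op names are translated to stock-PyTorch names before they reach this lookup, so a plain regex match is enough, and the op-type lookup now indexes `op_type_config_dict` instead of the wrong dict. A minimal runnable sketch of the resulting precedence, with hypothetical toy configs (not the library's data structures):

```python
import re

# Hypothetical toy configs illustrating the precedence in to_config_mapping:
# global config first, then an op-type match, then any op-name regex match wins.
global_config = {"w_dtype": "int8"}
op_type_config_dict = {"Linear": {"w_dtype": "int8", "act_algo": "minmax"}}
op_name_config_dict = {"fc1.*": {"w_dtype": "fp32"}}  # fall back fc1 by name pattern

config_mapping = {}
for op_name, op_type in [("fc1", "Linear"), ("fc2", "Linear")]:
    config_mapping[(op_name, op_type)] = global_config
    if op_type in op_type_config_dict:
        config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]  # the fixed lookup
    for pattern, cfg in op_name_config_dict.items():
        if re.match(pattern, op_name):
            config_mapping[(op_name, op_type)] = cfg

print(config_mapping[("fc1", "Linear")])  # {'w_dtype': 'fp32'} -- name pattern wins
print(config_mapping[("fc2", "Linear")])  # {'w_dtype': 'int8', 'act_algo': 'minmax'}
```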

neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py (+3, -3)

```diff
@@ -56,7 +56,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     """
     assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."

-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)

     # check smoothquant folding value
     recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
@@ -121,7 +121,7 @@ def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model


@@ -185,7 +185,7 @@ def qdq_quantize(
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(tune_cfg["op"])
     return model

```
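Both quantize paths here need only the original four return values, so the new fifth element (`op_name_info`, see utility.py below) is discarded with a trailing `_`; `dump_model_op_stats` now takes the per-op sub-dict `tune_cfg["op"]` directly, matching its reworked signature.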

neural_compressor/torch/algorithms/static_quant/static_quant.py (+4, -3)

```diff
@@ -51,8 +51,9 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     Returns:
         A quantized model.
     """
-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name = get_quantizable_ops_recursively(model, example_inputs)
-    cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)  # update json file in ipex_config_path
+    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
+    # update json file in ipex_config_path; map ipex op_name to pt op_name
+    user_cfg = cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
     model.eval()

     # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -82,7 +83,7 @@ def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
     model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg)
+    dump_model_op_stats(user_cfg)
     return model

```
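Unlike the smooth-quant path, `static_quantize` keeps what `cfg_to_qconfig` now returns: `user_cfg`, the user configuration re-keyed to stock-PyTorch op names, which is what `dump_model_op_stats` receives, so the op statistics report PyTorch names rather than raw IPEX ones.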

neural_compressor/torch/algorithms/static_quant/utility.py (+32, -6)
```diff
@@ -16,6 +16,7 @@
 import json
 import os
 import re
+from collections import OrderedDict
 from typing import Dict, List, Union

 import torch
```
```diff
@@ -66,9 +67,10 @@
 def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name):  # pragma: no cover
     assert cfgs is not None, "No configure for IPEX int8 model..."
     op_infos = copy.deepcopy(op_infos_from_cfgs)
-    cfgs = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
+    cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
     with open(ipex_config_path, "w") as write_f:
         json.dump(cfgs, write_f, indent=4)
+    return user_cfg


 def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name):  # pragma: no cover
```
```diff
@@ -83,6 +85,15 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
     """
+    tmp_user_cfg = OrderedDict()
+    for op in user_cfg:  # map ipex op_name to pt op_name
+        for i, op_name in enumerate(op):
+            for ops, _ in op_infos_from_cfgs.items():
+                if "fqn" in op_infos_from_cfgs[ops].keys() and op_infos_from_cfgs[ops]["fqn"] == op_name:
+                    ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
+                    tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
+                    break
+    user_cfg = tmp_user_cfg
     for op_name in user_cfg:
         inc_op_cfg = user_cfg[op_name]
         for i, name in enumerate(op_name[0]):
```
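A minimal runnable sketch of what this re-keying does, with toy stand-ins for the IPEX structures (the dict shapes here are hypothetical, modeled only on how the loop above indexes them):

```python
from collections import OrderedDict

# Toy stand-ins (hypothetical shapes): op_infos_from_cfgs maps an op id to its
# fqn and raw IPEX op type; user_cfg arrives keyed by fqn-based tuples.
unify_op_type_mapping_ipex = {"<class 'torch.nn.modules.linear.Linear'>": "Linear"}
op_infos_from_cfgs = {("fc1", 0): {"fqn": "fc1", "op_type": "<class 'torch.nn.modules.linear.Linear'>"}}
user_cfg = {("fc1",): {"weight": {"dtype": "fp32"}}}

tmp_user_cfg = OrderedDict()
for op in user_cfg:  # map ipex op_name to pt op_name
    for op_name in op:
        for ops, info in op_infos_from_cfgs.items():
            if info.get("fqn") == op_name:
                ori_op = (tuple(ops), unify_op_type_mapping_ipex[info["op_type"]])
                tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
                break

for key in tmp_user_cfg:
    print(key)  # ((('fc1', 0),), 'Linear') -- the stock-PyTorch style key
```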
```diff
@@ -142,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
             else:
                 pass
             cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
-    return cfgs
+    return cfgs, user_cfg


 def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False):  # pragma: no cover
```
```diff
@@ -212,6 +223,7 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
         cfgs (dict): dict of configuration
     """
     quantizable_ops = []
+    op_name_info = []
     # group ops by position for transform-based model
     detector = TransformerBasedModelBlockPatternDetector(model)
     detect_result = detector.detect_block()
```
```diff
@@ -277,17 +289,30 @@
             if ipex_op_type in unify_op_type_mapping_ipex:
                 quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type]))
                 map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                if "class" in ipex_op_type:  # "<class 'torch.nn.modules.activation.ReLU'>"
+                    op_type = ipex_op_type.split("'")[1]
+                    op_name_info.append((module_fqn, eval(op_type)))
+                elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
+                    method = ipex_op_type.split("'")[1]
+                    op_type = getattr(
+                        torch._C._TensorBase if ipex_ver.release < Version("2.2") else torch._C.TensorBase, method
+                    )
+                    op_name_info.append((module_fqn, op_type))
+                else:
+                    op_name_info.append((module_fqn, op_type))
             else:
                 re_flag = False
                 for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
                     if re.match(pattern, ipex_op_type):
                         re_flag = True
                         quantizable_ops.append((tuple(name), unify_op_type))
                         map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
+                        op_name_info.append((module_fqn, ipex_op_type))
                         break
                 if not re_flag:
                     quantizable_ops.append((tuple(name), ipex_op_type))
                     map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
+                    op_name_info.append((module_fqn, ipex_op_type))
         else:
             op_type = ""
             for op_name in name:
```
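The two repr formats being parsed are easiest to see in isolation. A small sketch using the example strings from the diff's own comments (pure string handling, no torch required):

```python
# IPEX reports op types as Python reprs; splitting on single quotes recovers
# the importable class path or the tensor-method name.
class_repr = "<class 'torch.nn.modules.activation.ReLU'>"
method_repr = "<method 'add' of 'torch._C._TensorBase' objects>"

print(class_repr.split("'")[1])   # torch.nn.modules.activation.ReLU
print(method_repr.split("'")[1])  # add
```

The class path is then resolved with `eval` and the method with `getattr` on the tensor base class, which was renamed from `torch._C._TensorBase` to `torch._C.TensorBase`, hence the version gate at 2.2.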
```diff
@@ -302,14 +327,15 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
             _op_cfg_id = name[0][2]
             module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"]
             map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn
+            op_name_info.append((module_fqn, op_type))

     logger.debug("Map op name to fqn: ")
     logger.debug(map_op_name_to_fqn)
     logger.info("Attention Blocks : ")
     logger.info(attention_block)
     logger.info("FFN Blocks : ")
     logger.info(ffn_blocks)
-    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name
+    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, op_name_info


 def simple_inference(q_model, example_inputs, iterations=1):
```
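`op_name_info` thus pairs each quantizable op's module fqn with a stock-PyTorch op type (an `nn.Module` class, a tensor method, or a unified type string), and rides along as the new fifth element of the return tuple.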
```diff
@@ -323,16 +349,16 @@
     q_model(example_inputs)


-def dump_model_op_stats(tune_cfg):
+def dump_model_op_stats(user_cfg):
     """This is a function to dump quantizable ops of model to user.

     Args:
-        tune_cfg (dict): quantization config
+        user_cfg (dict): quantization config
     Returns:
         None
     """
     res = dict()
-    for k, v in tune_cfg["op"].items():
+    for k, v in user_cfg.items():
         op_type_list = k[-1].split("><")
         op_type = ""
         for op in op_type_list:
```

neural_compressor/torch/quantization/config.py (+2, -2)

```diff
@@ -818,7 +818,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
     def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.static_quant import get_quantizable_ops_recursively

-        model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
+        _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info

     @classmethod
@@ -923,7 +923,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
     def get_model_info(model: torch.nn.Module, example_inputs) -> List[Tuple[str, Callable]]:
         from neural_compressor.torch.algorithms.smooth_quant import get_quantizable_ops_recursively

-        model_info, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
+        model_info, _, _, _, _ = get_quantizable_ops_recursively(model, example_inputs=example_inputs)
         return model_info

     @classmethod
```
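Note the asymmetry: the static-quant `get_model_info` now takes the fifth return value (the fqn-based `op_name_info`) as its model info, while the smooth-quant variant keeps the first (`quantizable_ops`) and simply discards the extra element.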

test/3x/torch/quantization/test_static_quant.py (+15)

```diff
@@ -49,6 +49,21 @@ def test_static_quant_default(self):
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"

+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_static_quant_fallback(self):
+        fp32_model = copy.deepcopy(self.fp32_model)
+        quant_config = get_default_static_config()
+        example_inputs = self.input
+        # fallback by op_type
+        quant_config.set_local(torch.nn.modules.linear.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
+        # fallback by op_name
+        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
         "act_sym, act_algo",
```
