
Commit ba16504

fix tune_cfg issue for 3.x static quant (#1718)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
1 parent 137fa3a

File tree

3 files changed: +199 -476 lines

neural_compressor/torch/algorithms/smooth_quant/utility.py (+2 -288)
@@ -16,125 +16,28 @@
 import json
 import os
 import re
-import subprocess
 from collections import UserDict

-import cpuinfo
 import intel_extension_for_pytorch as ipex
 import numpy
-import psutil
 import torch
 import tqdm
 from packaging.version import Version

 from neural_compressor.torch.algorithms.static_quant import (
     TransformerBasedModelBlockPatternDetector,
     dump_model_op_stats,
-    get_quantizable_ops_from_cfgs,
+    generate_activation_observer,
+    get_quantizable_ops_recursively,
     ipex_config_path,
-    paser_cfgs,
     simple_inference,
-    unify_op_type_mapping_ipex,
 )
 from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger

 version = get_torch_version()
 ipex_ver = get_ipex_version()


-def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False):  # pragma: no cover
-    """This is a helper method to generate an activation observer.
-
-    Args:
-        scheme (str): Quantization scheme to be used.
-        algorithm (str): What algorithm for computing the quantization parameters based on.
-
-    Returns:
-        An observer.
-    """
-    kl_activation_observer = {
-        "name": "HistogramObserver",
-        "bins": 2048,
-        "upsample_rate": 128,
-        "dtype": "torch.quint8",
-        "qscheme": "torch.per_tensor_affine",
-        "reduce_range": False,
-        "quant_min": 0,
-        "quant_max": 255,
-    }
-    minmax_activation_observer = {
-        "name": "MinMaxObserver",
-        "dtype": "torch.quint8",
-        "qscheme": "torch.per_tensor_affine",
-        "reduce_range": False,
-        "quant_min": 0,
-        "quant_max": 255,
-    }
-    smoothquant_kl_activation_observer = {
-        "name": "SmoothQuantActivationObserver",
-        "smooth_quant_enabled": smooth_quant_enable,
-        "dtype": "torch.quint8",
-        "qscheme": "torch.per_tensor_affine",
-        "reduce_range": False,
-        "quant_min": 0,
-        "quant_max": 255,
-        "alpha": 0.5,
-        "act_observer": kl_activation_observer,
-        "act_ic_observer": {
-            "name": "PerChannelMinMaxObserver",
-            "ch_axis": -1,
-            "dtype": "torch.quint8",
-            "qscheme": "torch.per_channel_affine",
-            "reduce_range": False,
-            "quant_min": 0,
-            "quant_max": 255,
-        },
-    }
-    smoothquant_minmax_activation_observer = {
-        "name": "SmoothQuantActivationObserver",
-        "smooth_quant_enabled": smooth_quant_enable,
-        "dtype": "torch.quint8",
-        "qscheme": "torch.per_tensor_affine",
-        "reduce_range": False,
-        "quant_min": 0,
-        "quant_max": 255,
-        "alpha": 0.5,
-        "act_observer": minmax_activation_observer,
-        "act_ic_observer": {
-            "name": "PerChannelMinMaxObserver",
-            "ch_axis": -1,
-            "dtype": "torch.quint8",
-            "qscheme": "torch.per_channel_affine",
-            "reduce_range": False,
-            "quant_min": 0,
-            "quant_max": 255,
-        },
-    }
-    REDUCE_RANGE = False if CpuInfo().vnni else True
-    if REDUCE_RANGE:
-        minmax_activation_observer["reduce_range"] = REDUCE_RANGE
-        kl_activation_observer["reduce_range"] = REDUCE_RANGE
-    if scheme == "sym":
-        minmax_activation_observer["qscheme"] = "torch.per_tensor_symmetric"
-        minmax_activation_observer["dtype"] = "torch.qint8"
-        minmax_activation_observer["quant_min"] = -128
-        minmax_activation_observer["quant_max"] = 127
-        kl_activation_observer["qscheme"] = "torch.per_tensor_symmetric"
-        kl_activation_observer["dtype"] = "torch.qint8"
-        kl_activation_observer["quant_min"] = -128
-        kl_activation_observer["quant_max"] = 127
-    if smooth_quant and smooth_quant_enable:
-        if algorithm == "kl":
-            return smoothquant_kl_activation_observer
-        if algorithm == "minmax":
-            return smoothquant_minmax_activation_observer
-    else:
-        if algorithm == "kl":
-            return kl_activation_observer
-        if algorithm == "minmax":
-            return minmax_activation_observer
-
-
 def check_cfg_and_qconfig(
     tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name, smooth_quant=False
 ):  # pragma: no cover
@@ -223,131 +126,6 @@ def cfg_to_qconfig(
     return None


-def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
-    """Get all quantizable ops from model.
-
-    Args:
-        model (object): input model
-        example_inputs (dict|list|tuple|torch.Tensor): used to trace torch model.
-    Returns:
-        quantizable_ops (list): list of tuples of op_name and op_type.
-        cfgs (dict): dict of configuration
-    """
-    quantizable_ops = []
-    # group ops by position for transform-based model
-    detector = TransformerBasedModelBlockPatternDetector(model)
-    detect_result = detector.detect_block()
-    attention_block = detect_result.get("attention_blocks", None)
-    ffn_blocks = detect_result.get("ffn_blocks", None)
-    logger.info(f"Attention Blocks: {len(attention_block)}")
-    logger.info(f"FFN Blocks: {len(ffn_blocks)}")
-    if not os.path.exists(ipex_config_path):
-        assert isinstance(model, torch.nn.Module), "The model passed in is not the instance of torch.nn.Module"
-
-    if hasattr(model, "save_qconf_summary"):  # pragma: no cover
-        os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True)
-        model.save_qconf_summary(qconf_summary=ipex_config_path)
-    else:
-        model.eval()
-
-        # create a quantization config file for intel pytorch extension model
-        os.makedirs(os.path.dirname(ipex_config_path), exist_ok=True)
-        assert example_inputs is not None, "IPEX need q_dataloader or example_inputs to prepare the model"
-        from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig
-
-        if ipex_ver.release >= Version("2.1").release:
-            # HistogramObserver will cause a performance issue.
-            # static_qconfig = ipex.quantization.default_static_qconfig_mapping
-            qconfig = QConfig(
-                activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-            )
-            from torch.ao.quantization import QConfigMapping
-
-            static_qconfig = QConfigMapping().set_global(qconfig)
-        else:
-            static_qconfig = QConfig(
-                activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
-                weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
-            )
-
-        if isinstance(example_inputs, dict):
-            model = ipex.quantization.prepare(model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=True)
-        else:
-            model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=True)
-        simple_inference(model, example_inputs, iterations=1)
-        model.save_qconf_summary(qconf_summary=ipex_config_path)
-
-    map_op_name_to_fqn = {}
-    with open(ipex_config_path, "r") as f:
-        cfgs = json.load(f)
-        if ipex_ver.release < Version("1.12.0").release:  # pragma: no cover
-            for op_cfg in cfgs:
-                if op_cfg["name"] in unify_op_type_mapping_ipex:
-                    quantizable_ops.append((op_cfg["id"], unify_op_type_mapping_ipex[op_cfg["name"]]))
-                else:
-                    re_flag = False
-                    for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
-                        if re.match(pattern, op_cfg["name"]):
-                            re_flag = True
-                            quantizable_ops.append((op_cfg["id"], unify_op_type))
-                            break
-                    if not re_flag:
-                        quantizable_ops.append((op_cfg["id"], op_cfg["name"]))
-        else:
-            (
-                ops_name,
-                op_infos_from_cfgs,
-                input_tensor_id_op_name,
-                output_tensor_id_op_name,
-            ) = paser_cfgs(cfgs)
-            quantizable_op_names = get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_id_op_name)
-            for name in quantizable_op_names:
-                # name : list
-                if len(name) == 1:
-                    module_key = name[0][0]
-                    op_cfg_id = name[0][2]
-                    ipex_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"]
-                    module_fqn = cfgs[module_key]["q_op_infos"][op_cfg_id].get("fqn", None)
-
-                    if ipex_op_type in unify_op_type_mapping_ipex:
-                        quantizable_ops.append((tuple(name), unify_op_type_mapping_ipex[ipex_op_type]))
-                        map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
-                    else:
-                        re_flag = False
-                        for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
-                            if re.match(pattern, ipex_op_type):
-                                re_flag = True
-                                quantizable_ops.append((tuple(name), unify_op_type))
-                                map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
-                                break
-                        if not re_flag:
-                            quantizable_ops.append((tuple(name), ipex_op_type))
-                            map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
-                else:
-                    op_type = ""
-                    for op_name in name:
-                        module_key = op_name[0]
-                        op_cfg_id = op_name[2]
-                        single_op_type = cfgs[module_key]["q_op_infos"][op_cfg_id]["op_type"]
-                        if single_op_type in unify_op_type_mapping_ipex:
-                            single_op_type = unify_op_type_mapping_ipex[single_op_type]
-                        op_type += "&" + single_op_type if op_type else single_op_type
-                    quantizable_ops.append((tuple(name), op_type))
-                    _module_key = name[0][0]
-                    _op_cfg_id = name[0][2]
-                    module_fqn = cfgs[_module_key]["q_op_infos"][_op_cfg_id]["fqn"]
-                    map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn
-
-    logger.debug("Map op name to fqn: ")
-    logger.debug(map_op_name_to_fqn)
-    logger.info("Attention Blocks : ")
-    logger.info(attention_block)
-    logger.info("FFN Blocks : ")
-    logger.info(ffn_blocks)
-    return quantizable_ops, cfgs, op_infos_from_cfgs, output_tensor_id_op_name
-
-
 def get_parent(node, all_parents=False):  # pragma: no cover
     if node.inputs() is None:
         return None
@@ -2275,67 +2053,3 @@ def forward(self, x):
         output = self.orig_layer(x)
         self.output = output
         return output
-
-
-class CpuInfo(object):  # pragma: no cover
-    """Get CPU Info."""
-
-    def __init__(self):
-        """Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket."""
-        self._bf16 = False
-        self._vnni = False
-        info = cpuinfo.get_cpu_info()
-        if "arch" in info and "X86" in info["arch"]:
-            cpuid = cpuinfo.CPUID()
-            max_extension_support = cpuid.get_max_extension_support()
-            if max_extension_support >= 7:
-                ecx = cpuid._run_asm(
-                    b"\x31\xC9",  # xor ecx, ecx
-                    b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\x89\xC8" b"\xC3",  # mov eax, 7 # cpuid # mov ax, cx # ret
-                )
-                self._vnni = bool(ecx & (1 << 11))
-                eax = cpuid._run_asm(
-                    b"\xB9\x01\x00\x00\x00",  # mov ecx, 1
-                    b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3",  # mov eax, 7 # cpuid # ret
-                )
-                self._bf16 = bool(eax & (1 << 5))
-        if "arch" in info and "ARM" in info["arch"]:  # pragma: no cover
-            self._sockets = 1
-        else:
-            self._sockets = self.get_number_of_sockets()
-        self._cores = psutil.cpu_count(logical=False)
-        self._cores_per_socket = int(self._cores / self._sockets)
-
-    @property
-    def bf16(self):
-        """Get whether it is bf16."""
-        return self._bf16
-
-    @property
-    def vnni(self):
-        """Get whether it is vnni."""
-        return self._vnni
-
-    @property
-    def cores_per_socket(self):
-        """Get the cores per socket."""
-        return self._cores_per_socket
-
-    def get_number_of_sockets(self) -> int:
-        """Get number of sockets in platform."""
-        cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l"
-        if psutil.WINDOWS:
-            cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"'
-
-        with subprocess.Popen(
-            args=cmd,
-            shell=True,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            universal_newlines=False,
-        ) as proc:
-            proc.wait()
-            if proc.stdout:
-                for line in proc.stdout:
-                    return int(line.decode("utf-8", errors="ignore").strip())
-            return 0
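
Taken together, the removals above mean the SmoothQuant utility no longer carries its own copies of these helpers. A minimal sketch (not part of this commit's diff, using only names visible in the hunks above) of the import pattern the change establishes:

# Sketch only: after this commit, smooth_quant reuses the shared helpers
# from the static_quant package instead of defining local duplicates.
from neural_compressor.torch.algorithms.static_quant import (
    generate_activation_observer,
    get_quantizable_ops_recursively,
)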
