Skip to content

Commit 503d9ef

Browse files
authored
Add op statistics dump for woq (#1876)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 5a0374e commit 503d9ef

File tree

4 files changed

+114
-3
lines changed

4 files changed

+114
-3
lines changed

neural_compressor/torch/algorithms/weight_only/utility.py

-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import math
16-
1715
import torch
1816

1917
from neural_compressor.torch.utils import accelerator, device_synchronize, logger

neural_compressor/torch/quantization/algorithm_entry.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,14 @@
4545
StaticQuantConfig,
4646
TEQConfig,
4747
)
48-
from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo
48+
from neural_compressor.torch.utils import (
49+
dump_model_op_stats,
50+
get_quantizer,
51+
is_ipex_imported,
52+
logger,
53+
postprocess_model,
54+
register_algo,
55+
)
4956
from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT
5057

5158

@@ -89,6 +96,7 @@ def rtn_entry(
8996
model.qconfig = configs_mapping
9097
model.save = MethodType(save, model)
9198
postprocess_model(model, mode, quantizer)
99+
dump_model_op_stats(mode, configs_mapping)
92100
return model
93101

94102

@@ -141,6 +149,7 @@ def gptq_entry(
141149
model.qconfig = configs_mapping
142150
model.save = MethodType(save, model)
143151
postprocess_model(model, mode, quantizer)
152+
dump_model_op_stats(mode, configs_mapping)
144153

145154
return model
146155

@@ -361,6 +370,7 @@ def awq_quantize_entry(
361370
model.qconfig = configs_mapping
362371
model.save = MethodType(save, model)
363372
postprocess_model(model, mode, quantizer)
373+
dump_model_op_stats(mode, configs_mapping)
364374
return model
365375

366376

@@ -415,6 +425,7 @@ def teq_quantize_entry(
415425
model.qconfig = configs_mapping
416426
model.save = MethodType(save, model)
417427
postprocess_model(model, mode, quantizer)
428+
dump_model_op_stats(mode, configs_mapping)
418429

419430
return model
420431

@@ -491,6 +502,7 @@ def autoround_quantize_entry(
491502
model.qconfig = configs_mapping
492503
model.save = MethodType(save, model)
493504
postprocess_model(model, mode, quantizer)
505+
dump_model_op_stats(mode, configs_mapping)
494506
return model
495507

496508

@@ -511,6 +523,7 @@ def hqq_entry(
511523
quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
512524
model = quantizer.execute(model, mode=mode)
513525
postprocess_model(model, mode, quantizer)
526+
dump_model_op_stats(mode, configs_mapping)
514527

515528
return model
516529

neural_compressor/torch/utils/utility.py

+99
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Callable, Dict, List, Tuple, Union
1717

1818
import torch
19+
from prettytable import PrettyTable
1920
from typing_extensions import TypeAlias
2021

2122
from neural_compressor.common.utils import LazyImport, Mode, logger
@@ -163,3 +164,101 @@ def postprocess_model(model, mode, quantizer):
163164
elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
164165
if getattr(model, "quantizer", False):
165166
del model.quantizer
167+
168+
169+
class Statistics: # pragma: no cover
170+
"""The statistics printer."""
171+
172+
def __init__(self, data, header, field_names, output_handle=logger.info):
173+
"""Init a Statistics object.
174+
175+
Args:
176+
data: The statistics data
177+
header: The table header
178+
field_names: The field names
179+
output_handle: The output logging method
180+
"""
181+
self.field_names = field_names
182+
self.header = header
183+
self.data = data
184+
self.output_handle = output_handle
185+
self.tb = PrettyTable(min_table_width=40)
186+
187+
def print_stat(self):
188+
"""Print the statistics."""
189+
valid_field_names = []
190+
for index, value in enumerate(self.field_names):
191+
if index < 2:
192+
valid_field_names.append(value)
193+
continue
194+
195+
if any(i[index] for i in self.data):
196+
valid_field_names.append(value)
197+
self.tb.field_names = valid_field_names
198+
for i in self.data:
199+
tmp_data = []
200+
for index, value in enumerate(i):
201+
if self.field_names[index] in valid_field_names:
202+
tmp_data.append(value)
203+
if any(tmp_data[1:]):
204+
self.tb.add_row(tmp_data)
205+
lines = self.tb.get_string().split("\n")
206+
self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
207+
for i in lines:
208+
self.output_handle(i)
209+
210+
211+
def dump_model_op_stats(mode, tune_cfg):
212+
"""This is a function to dump quantizable ops of model to user.
213+
214+
Args:
215+
model (object): input model
216+
tune_cfg (dict): quantization config
217+
Returns:
218+
None
219+
"""
220+
if mode == Mode.PREPARE:
221+
return
222+
res = {}
223+
# collect all dtype info and build empty results with existing op_type
224+
dtype_set = set()
225+
for op, config in tune_cfg.items():
226+
op_type = op[1]
227+
config = config.to_dict()
228+
# import pdb; pdb.set_trace()
229+
if not config["dtype"] == "fp32":
230+
num_bits = config["bits"]
231+
group_size = config["group_size"]
232+
dtype_str = "A32W{}G{}".format(num_bits, group_size)
233+
dtype_set.add(dtype_str)
234+
dtype_set.add("FP32")
235+
dtype_list = list(dtype_set)
236+
dtype_list.sort()
237+
238+
for op, config in tune_cfg.items():
239+
config = config.to_dict()
240+
op_type = op[1]
241+
if op_type not in res.keys():
242+
res[op_type] = {dtype: 0 for dtype in dtype_list}
243+
244+
# fill in results with op_type and dtype
245+
for op, config in tune_cfg.items():
246+
config = config.to_dict()
247+
if config["dtype"] == "fp32":
248+
res[op_type]["FP32"] += 1
249+
else:
250+
num_bits = config["bits"]
251+
group_size = config["group_size"]
252+
dtype_str = "A32W{}G{}".format(num_bits, group_size)
253+
res[op_type][dtype_str] += 1
254+
255+
# update stats format for dump.
256+
field_names = ["Op Type", "Total"]
257+
field_names.extend(dtype_list)
258+
output_data = []
259+
for op_type in res.keys():
260+
field_results = [op_type, sum(res[op_type].values())]
261+
field_results.extend([res[op_type][dtype] for dtype in dtype_list])
262+
output_data.append(field_results)
263+
264+
Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()

requirements_pt.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
numpy < 2.0
22
peft==0.10.0
3+
prettytable
34
psutil
45
py-cpuinfo
56
pydantic

0 commit comments

Comments
 (0)