|
16 | 16 | from typing import Callable, Dict, List, Tuple, Union
|
17 | 17 |
|
18 | 18 | import torch
|
| 19 | +from prettytable import PrettyTable |
19 | 20 | from typing_extensions import TypeAlias
|
20 | 21 |
|
21 | 22 | from neural_compressor.common.utils import LazyImport, Mode, logger
|
@@ -163,3 +164,101 @@ def postprocess_model(model, mode, quantizer):
|
163 | 164 | elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
|
164 | 165 | if getattr(model, "quantizer", False):
|
165 | 166 | del model.quantizer
|
| 167 | + |
| 168 | + |
| 169 | +class Statistics: # pragma: no cover |
| 170 | + """The statistics printer.""" |
| 171 | + |
| 172 | + def __init__(self, data, header, field_names, output_handle=logger.info): |
| 173 | + """Init a Statistics object. |
| 174 | +
|
| 175 | + Args: |
| 176 | + data: The statistics data |
| 177 | + header: The table header |
| 178 | + field_names: The field names |
| 179 | + output_handle: The output logging method |
| 180 | + """ |
| 181 | + self.field_names = field_names |
| 182 | + self.header = header |
| 183 | + self.data = data |
| 184 | + self.output_handle = output_handle |
| 185 | + self.tb = PrettyTable(min_table_width=40) |
| 186 | + |
| 187 | + def print_stat(self): |
| 188 | + """Print the statistics.""" |
| 189 | + valid_field_names = [] |
| 190 | + for index, value in enumerate(self.field_names): |
| 191 | + if index < 2: |
| 192 | + valid_field_names.append(value) |
| 193 | + continue |
| 194 | + |
| 195 | + if any(i[index] for i in self.data): |
| 196 | + valid_field_names.append(value) |
| 197 | + self.tb.field_names = valid_field_names |
| 198 | + for i in self.data: |
| 199 | + tmp_data = [] |
| 200 | + for index, value in enumerate(i): |
| 201 | + if self.field_names[index] in valid_field_names: |
| 202 | + tmp_data.append(value) |
| 203 | + if any(tmp_data[1:]): |
| 204 | + self.tb.add_row(tmp_data) |
| 205 | + lines = self.tb.get_string().split("\n") |
| 206 | + self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") |
| 207 | + for i in lines: |
| 208 | + self.output_handle(i) |
| 209 | + |
| 210 | + |
| 211 | +def dump_model_op_stats(mode, tune_cfg): |
| 212 | + """This is a function to dump quantizable ops of model to user. |
| 213 | +
|
| 214 | + Args: |
| 215 | + model (object): input model |
| 216 | + tune_cfg (dict): quantization config |
| 217 | + Returns: |
| 218 | + None |
| 219 | + """ |
| 220 | + if mode == Mode.PREPARE: |
| 221 | + return |
| 222 | + res = {} |
| 223 | + # collect all dtype info and build empty results with existing op_type |
| 224 | + dtype_set = set() |
| 225 | + for op, config in tune_cfg.items(): |
| 226 | + op_type = op[1] |
| 227 | + config = config.to_dict() |
| 228 | + # import pdb; pdb.set_trace() |
| 229 | + if not config["dtype"] == "fp32": |
| 230 | + num_bits = config["bits"] |
| 231 | + group_size = config["group_size"] |
| 232 | + dtype_str = "A32W{}G{}".format(num_bits, group_size) |
| 233 | + dtype_set.add(dtype_str) |
| 234 | + dtype_set.add("FP32") |
| 235 | + dtype_list = list(dtype_set) |
| 236 | + dtype_list.sort() |
| 237 | + |
| 238 | + for op, config in tune_cfg.items(): |
| 239 | + config = config.to_dict() |
| 240 | + op_type = op[1] |
| 241 | + if op_type not in res.keys(): |
| 242 | + res[op_type] = {dtype: 0 for dtype in dtype_list} |
| 243 | + |
| 244 | + # fill in results with op_type and dtype |
| 245 | + for op, config in tune_cfg.items(): |
| 246 | + config = config.to_dict() |
| 247 | + if config["dtype"] == "fp32": |
| 248 | + res[op_type]["FP32"] += 1 |
| 249 | + else: |
| 250 | + num_bits = config["bits"] |
| 251 | + group_size = config["group_size"] |
| 252 | + dtype_str = "A32W{}G{}".format(num_bits, group_size) |
| 253 | + res[op_type][dtype_str] += 1 |
| 254 | + |
| 255 | + # update stats format for dump. |
| 256 | + field_names = ["Op Type", "Total"] |
| 257 | + field_names.extend(dtype_list) |
| 258 | + output_data = [] |
| 259 | + for op_type in res.keys(): |
| 260 | + field_results = [op_type, sum(res[op_type].values())] |
| 261 | + field_results.extend([res[op_type][dtype] for dtype in dtype_list]) |
| 262 | + output_data.append(field_results) |
| 263 | + |
| 264 | + Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat() |
0 commit comments