16
16
import json
17
17
import os
18
18
import re
19
+ from collections import OrderedDict
19
20
from typing import Dict , List , Union
20
21
21
22
import torch
66
67
def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name):  # pragma: no cover
    """Translate an INC tuning config into the IPEX JSON qconfig file on disk.

    Args:
        tune_cfg (dict): tuning configuration; its ``"op"`` entry maps op names
            to the per-op quantization settings chosen by the tuner.
        cfgs (dict): IPEX-generated configuration to be reconciled with
            ``tune_cfg`` and then dumped to ``ipex_config_path``.
        op_infos_from_cfgs (dict): op information parsed from ``cfgs``; it is
            deep-copied below so the caller's mapping is never mutated.
        output_tensor_id_op_name (dict): lookup from output tensor id to the
            producing op's name.

    Returns:
        dict: the user config produced by ``check_cfg_and_qconfig``, keyed by
        (op-name tuple, unified op type).
    """
    assert cfgs is not None, "No configure for IPEX int8 model..."
    # Reconcile the tuner's per-op choices against the IPEX configs on a copy
    # of the op-info mapping, keeping the original intact for the caller.
    op_infos = copy.deepcopy(op_infos_from_cfgs)
    cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, op_infos, output_tensor_id_op_name)
    # Persist the updated configuration; ipex_config_path is a module-level
    # path — TODO confirm it is always set before this function runs.
    with open(ipex_config_path, "w") as write_f:
        json.dump(cfgs, write_f, indent=4)
    return user_cfg
72
74
73
75
74
76
def check_cfg_and_qconfig (user_cfg , cfgs , op_infos_from_cfgs , output_tensor_ids_op_name ): # pragma: no cover
@@ -83,6 +85,15 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
83
85
Returns:
84
86
cfgs (dict): updated configs.
85
87
"""
88
+ tmp_user_cfg = OrderedDict ()
89
+ for op in user_cfg : # map ipex op_name to pt op_name
90
+ for i , op_name in enumerate (op ):
91
+ for ops , _ in op_infos_from_cfgs .items ():
92
+ if "fqn" in op_infos_from_cfgs [ops ].keys () and op_infos_from_cfgs [ops ]["fqn" ] == op_name :
93
+ ori_op = (tuple (ops ), unify_op_type_mapping_ipex [op_infos_from_cfgs [ops ]["op_type" ]])
94
+ tmp_user_cfg [((ori_op [0 ],), ori_op [1 ])] = user_cfg [op ]
95
+ break
96
+ user_cfg = tmp_user_cfg
86
97
for op_name in user_cfg :
87
98
inc_op_cfg = user_cfg [op_name ]
88
99
for i , name in enumerate (op_name [0 ]):
@@ -142,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
142
153
else :
143
154
pass
144
155
cfgs [name [0 ]][name [1 ]][name [2 ]] = ipex_op_cfg
145
- return cfgs
156
+ return cfgs , user_cfg
146
157
147
158
148
159
def generate_activation_observer (scheme , algorithm , smooth_quant = False , smooth_quant_enable = False ): # pragma: no cover
@@ -212,6 +223,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
212
223
cfgs (dict): dict of configuration
213
224
"""
214
225
quantizable_ops = []
226
+ op_name_info = []
215
227
# group ops by position for transform-based model
216
228
detector = TransformerBasedModelBlockPatternDetector (model )
217
229
detect_result = detector .detect_block ()
@@ -277,17 +289,30 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
277
289
if ipex_op_type in unify_op_type_mapping_ipex :
278
290
quantizable_ops .append ((tuple (name ), unify_op_type_mapping_ipex [ipex_op_type ]))
279
291
map_op_name_to_fqn [(tuple (name ), ipex_op_type )] = module_fqn
292
+ if "class" in ipex_op_type : # "<class 'torch.nn.modules.activation.ReLU'>"
293
+ op_type = ipex_op_type .split ("'" )[1 ]
294
+ op_name_info .append ((module_fqn , eval (op_type )))
295
+ elif "method" in ipex_op_type : # "<method 'add' of 'torch._C._TensorBase' objects>"
296
+ method = ipex_op_type .split ("'" )[1 ]
297
+ op_type = getattr (
298
+ torch ._C ._TensorBase if ipex_ver .release < Version ("2.2" ) else torch ._C .TensorBase , method
299
+ )
300
+ op_name_info .append ((module_fqn , op_type ))
301
+ else :
302
+ op_name_info .append ((module_fqn , op_type ))
280
303
else :
281
304
re_flag = False
282
305
for pattern , unify_op_type in unify_op_type_mapping_ipex ["re" ].items ():
283
306
if re .match (pattern , ipex_op_type ):
284
307
re_flag = True
285
308
quantizable_ops .append ((tuple (name ), unify_op_type ))
286
309
map_op_name_to_fqn [(tuple (name ), unify_op_type )] = module_fqn
310
+ op_name_info .append ((module_fqn , ipex_op_type ))
287
311
break
288
312
if not re_flag :
289
313
quantizable_ops .append ((tuple (name ), ipex_op_type ))
290
314
map_op_name_to_fqn [(tuple (name ), ipex_op_type )] = module_fqn
315
+ op_name_info .append ((module_fqn , ipex_op_type ))
291
316
else :
292
317
op_type = ""
293
318
for op_name in name :
@@ -302,14 +327,15 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
302
327
_op_cfg_id = name [0 ][2 ]
303
328
module_fqn = cfgs [_module_key ]["q_op_infos" ][_op_cfg_id ]["fqn" ]
304
329
map_op_name_to_fqn [(tuple (name ), op_type )] = module_fqn
330
+ op_name_info .append ((module_fqn , op_type ))
305
331
306
332
logger .debug ("Map op name to fqn: " )
307
333
logger .debug (map_op_name_to_fqn )
308
334
logger .info ("Attention Blocks : " )
309
335
logger .info (attention_block )
310
336
logger .info ("FFN Blocks : " )
311
337
logger .info (ffn_blocks )
312
- return quantizable_ops , cfgs , op_infos_from_cfgs , output_tensor_id_op_name
338
+ return quantizable_ops , cfgs , op_infos_from_cfgs , output_tensor_id_op_name , op_name_info
313
339
314
340
315
341
def simple_inference (q_model , example_inputs , iterations = 1 ):
@@ -323,16 +349,16 @@ def simple_inference(q_model, example_inputs, iterations=1):
323
349
q_model (example_inputs )
324
350
325
351
326
- def dump_model_op_stats (tune_cfg ):
352
+ def dump_model_op_stats (user_cfg ):
327
353
"""This is a function to dump quantizable ops of model to user.
328
354
329
355
Args:
330
- tune_cfg (dict): quantization config
356
+ user_cfg (dict): quantization config
331
357
Returns:
332
358
None
333
359
"""
334
360
res = dict ()
335
- for k , v in tune_cfg [ "op" ] .items ():
361
+ for k , v in user_cfg .items ():
336
362
op_type_list = k [- 1 ].split ("><" )
337
363
op_type = ""
338
364
for op in op_type_list :
0 commit comments