43
43
"<class 'torch.nn.modules.conv.Conv2d'>" : "Conv2d" ,
44
44
"<class 'torch.nn.modules.conv.Conv3d'>" : "Conv3d" ,
45
45
"<class 'torch.nn.modules.activation.ReLU'>" : "ReLU" ,
46
+ "<class 'torch.nn.modules.sparse.EmbeddingBag'>" : "EmbeddingBag" ,
46
47
"<method 'add' of 'torch._C._TensorBase' objects>" : "add" , # for IPEX < 2.2
47
48
"<method 'add' of 'torch._C.TensorBase' objects>" : "add" , # for IPEX >= 2.2
48
49
"<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>" : "AdaptiveAvgPool2d" ,
49
50
"Linear_Relu" : "Linear" ,
51
+ "Linear_add" : "Linear" ,
50
52
"<class 'torch.nn.modules.linear.Linear'>" : "Linear" ,
51
53
"<class 'torch.nn.modules.pooling.MaxPool2d'>" : "MaxPool2d" ,
52
- "re" : {"<built-in method matmul of type object at" : "matmul" },
54
+ "re" : {
55
+ "<built-in method matmul of type object at" : "matmul" ,
56
+ "<built-in method add of type object at" : "add" ,
57
+ "<built-in method bmm of type object at" : "bmm" ,
58
+ },
53
59
}
54
60
55
61
BLOCK_PATTERNS = [
@@ -85,6 +91,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
85
91
Returns:
86
92
cfgs (dict): updated configs.
87
93
"""
94
+ ori_user_cfg = copy .deepcopy (user_cfg )
88
95
tmp_user_cfg = OrderedDict ()
89
96
for op in user_cfg : # map ipex op_name to pt op_name
90
97
for i , op_name in enumerate (op ):
@@ -94,9 +101,9 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
94
101
ori_op = (tuple (ops ), unify_op_type_mapping_ipex [op_infos_from_cfgs [ops ]["op_type" ]])
95
102
tmp_user_cfg [((ori_op [0 ],), ori_op [1 ])] = user_cfg [op ]
96
103
break
97
- user_cfg = tmp_user_cfg
98
- for op_name in user_cfg :
99
- inc_op_cfg = user_cfg [op_name ]
104
+
105
+ for op_name in tmp_user_cfg :
106
+ inc_op_cfg = tmp_user_cfg [op_name ]
100
107
for i , name in enumerate (op_name [0 ]):
101
108
# to int8
102
109
ipex_op_cfg = op_infos_from_cfgs [name ]
@@ -154,7 +161,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
154
161
else :
155
162
pass
156
163
cfgs [name [0 ]][name [1 ]][name [2 ]] = ipex_op_cfg
157
- return cfgs , user_cfg
164
+ return cfgs , ori_user_cfg
158
165
159
166
160
167
def generate_activation_observer (scheme , algorithm , smooth_quant = False , smooth_quant_enable = False ): # pragma: no cover
@@ -333,8 +340,8 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
333
340
elif "method" in ipex_op_type : # "<method 'add' of 'torch._C._TensorBase' objects>"
334
341
method = ipex_op_type .split ("'" )[1 ]
335
342
op_name_info .append ((module_fqn , method ))
336
- elif "Convolution " in ipex_op_type : # "Convolution_Relu"
337
- op_name_info .append ((module_fqn , "Conv2d" ))
343
+ elif "_ " in ipex_op_type : # "Convolution_Relu", "Linear_Relu "
344
+ op_name_info .append ((module_fqn , ipex_op_type . split ( "_" )[ 0 ] ))
338
345
else :
339
346
re_flag = False
340
347
for pattern , unify_op_type in unify_op_type_mapping_ipex ["re" ].items ():
@@ -394,32 +401,7 @@ def dump_model_op_stats(user_cfg):
394
401
"""
395
402
res = dict ()
396
403
for k , v in user_cfg .items ():
397
- op_type_list = k [- 1 ].split ("><" )
398
- op_type = ""
399
- for op in op_type_list :
400
- if "class" in op :
401
- op_type = (
402
- op [op .rfind ("." ) + 1 : op .rfind ("'" )]
403
- if op_type == ""
404
- else op_type + "&" + op [op .rfind ("." ) + 1 : op .rfind ("'" )]
405
- )
406
- elif "method" in op :
407
- start = op .find ("'" ) + 1
408
- if start > 1 :
409
- op_type = (
410
- op [start : op .find ("'" , start )]
411
- if op_type == ""
412
- else op_type + "&" + op [start : op .find ("'" , start )]
413
- )
414
- else :
415
- start = op .find ("method" ) + 7
416
- op_type = (
417
- op [start : op .find (" " , start )]
418
- if op_type == ""
419
- else op_type + "&" + op [start : op .find (" " , start )]
420
- )
421
- else :
422
- op_type = op if op_type == "" else op_type + "&" + op
404
+ op_type = k [1 ]
423
405
if op_type not in res .keys ():
424
406
res [op_type ] = {"INT8" : 0 , "BF16" : 0 , "FP32" : 0 }
425
407
if v ["weight" ]["dtype" ] == "int8" :
0 commit comments