# limitations under the License.
"""Rewrite the FP32 operators to FP16 or BF16 operators."""

+from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Tuple

from torch.fx.subgraph_rewriter import Match
from typing_extensions import TypeAlias

-from neural_compressor.common import utils
+from neural_compressor.common import logger, utils


# =============================================================================
# Search and replace patterns
@@ -50,25 +51,44 @@ class PatternPair:

# key: torch func
# value: the tuple of args
-FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
+FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]


-# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
-# TODO: complete the mapping
+# Align with xiq, as it relies on xiq's set_module_xx capability
FN_ARGS_MAPPING: FuncArgsMappingType = {
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
-    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    # Note: order matters
+    torch.nn.functional.linear: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
+        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
+    ],
+    torch.nn.functional.conv2d: [
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
+        (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
+    ],
+    torch.matmul: [
+        (torch.randn(0, 0), torch.randn(0, 0)),  # 2D matmul
+        (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # batched 3D matmul
+        (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # batched 4D matmul
+    ],
}
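
Since `torch.fx` subgraph matching is arity-sensitive, `linear` with and without bias trace to different graphs, which is why each function now maps to a *list* of example-arg tuples. As a rough illustration (hypothetical names; the real `pattern_factory`/`PatternPair` are defined earlier in this file), the pair generated for bias-free `linear` would look something like:

```python
import torch

# Hypothetical sketch of one search/replace pair for bias-free linear;
# the real pattern_factory/PatternPair live elsewhere in this file.
def linear_search_pattern(x, weight):
    return torch.nn.functional.linear(x, weight)

def linear_replace_pattern(x, weight):
    # Run the op in half precision and cast the result back to fp32.
    return torch.nn.functional.linear(x.half(), weight.half()).float()

# Applied via torch.fx.subgraph_rewriter.replace_pattern(gm, search, replace).
```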
-# TODO: complete the mapping
-FN_ATEN_OPS_MAPPING = {
-    torch.nn.functional.linear: torch.ops.aten.linear.default,
+
+# module cls <-> function name
+NN_MODULES_TO_NN_FN = {
+    torch.nn.Linear: torch.nn.functional.linear,
+    torch.nn.Conv2d: torch.nn.functional.conv2d,
}

+# Use the mapping from xiq
+FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()
+
SUPPORTED_OPERATORS = FN_ATEN_OPS_MAPPING.values()
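
Judging from the hand-written dict it replaces above, `xiq._map_module_function_to_aten_operator_type()` should yield a mapping of roughly this shape (an assumption for illustration, not the helper's verbatim output):

```python
import torch

# Assumed shape of the returned mapping, based on the removed literal dict.
FN_ATEN_OPS_MAPPING_SKETCH = {
    torch.nn.functional.linear: torch.ops.aten.linear.default,
    torch.nn.functional.conv2d: torch.ops.aten.conv2d.default,
}
```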

PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
-HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}
+HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {
+    torch.float16: defaultdict(list),
+    torch.bfloat16: defaultdict(list),
+}

# FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
# BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]
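
The move from plain dicts to `defaultdict(list)` is what lets `_register_pattern_pair` below append multiple `PatternPair`s per function without key-existence checks; a minimal demonstration:

```python
import torch
from collections import defaultdict

registry = defaultdict(list)
# First access creates an empty list, so the w/ and w/o bias patterns
# can accumulate under the same torch function key.
registry[torch.nn.functional.linear].append("pattern_wo_bias")
registry[torch.nn.functional.linear].append("pattern_w_bias")
assert len(registry[torch.nn.functional.linear]) == 2
```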
@@ -98,15 +118,18 @@ def replace_fn_wrapper(fn_args, fn):

def _register_pattern_pair(dtype: torch.dtype) -> None:
-    for fn, fn_args in FN_ARGS_MAPPING.items():
-        pattern_pair = pattern_factory(fn, fn_args)
-        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
-    utils.logger.info(
+    for fn, fn_args_lst in FN_ARGS_MAPPING.items():
+        for fn_args in fn_args_lst:
+            logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
+            pattern_pair = pattern_factory(fn, fn_args)
+            HALF_PRECISION_PATTERN_REGISTRY[dtype][fn].append(pattern_pair)
+    utils.logger.debug(
        f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
    )


_register_pattern_pair(torch.float16)
+_register_pattern_pair(torch.bfloat16)


def get_filter_fn(node_list, fn):
@@ -182,9 +205,10 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):

def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
    """Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
-    for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
-        apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
-    utils.logger.info("Half precision conversion is done:")
+    for pattern_pair_lst in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
+        for pattern_pair in pattern_pair_lst:
+            apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
+    utils.logger.info(f"Half precision conversion ({target_dtype}) completed.")
    if utils.level_name == "DEBUG":  # pragma: no cover
        gm.print_readable(True)
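
A hedged usage sketch: the registered patterns match ATen operators, so `transformation` is meant for an ATen-level graph such as one captured with `torch.export` (the `TinyModel` below is hypothetical):

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

# Capture an ATen-level graph so node targets are aten operators.
ep = torch.export.export(TinyModel(), (torch.randn(2, 8),))
gm = ep.module()
candidates = [node.name for node in gm.graph.nodes]
transformation(gm, candidates, target_dtype=torch.bfloat16)
```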
@@ -201,11 +225,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
    op_name_filters = []
    for op_type_name, config in op_type_configs.items():  # pragma: no cover
        op_type = getattr(torch.nn, op_type_name)
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
            filter = xpq._get_module_type_filter(op_type)
            op_type_filters.append(filter)
    for op_name, config in op_name_configs.items():
-        if config.act_dtype == "fp16":  # pragma: no cover
+        if config.act_dtype in ["fp16", "bf16"]:  # pragma: no cover
            filter = xpq._get_module_name_filter(op_name)
            op_name_filters.append(filter)
    node_set_from_user_config = set()
@@ -237,5 +261,7 @@ def get_half_precision_node_set(gm, config):
    for node in possible_node_set:
        if node.target in SUPPORTED_OPERATORS:
            half_precision_node_set.add(node)
-    utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
+    utils.logger.info(
+        f"Found {len(half_precision_node_set)} nodes to convert to half precision: {half_precision_node_set}"
+    )
    return half_precision_node_set