Commit a58638c

Enhance the set_local for operator type (#1745)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fdb5097 commit a58638c
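
This commit lets set_local accept an operator type (a module class or other callable) in addition to an operator-name string. The type is converted to its __name__ string when the config mapping is built, so mapping keys become (op_name, op_type_name) string pairs. A minimal usage sketch against the PyTorch 3.x API, assuming RTNConfig is importable from neural_compressor.torch.quantization as in the tests below and the model has Linear layers named fc1/fc2/fc3:

import torch

from neural_compressor.torch.quantization import RTNConfig

# Default: 4-bit weight-only quantization for every op.
quant_config = RTNConfig(bits=4, dtype="nf4")

# Per-name override (previously the only option): only "fc1" gets 6 bits.
quant_config.set_local("fc1", RTNConfig(bits=6, dtype="int8"))

# Per-type override (new in this commit): every torch.nn.Linear gets 6 bits.
quant_config.set_local(torch.nn.Linear, RTNConfig(bits=6, dtype="int8"))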

File tree

4 files changed: +86 −12 lines

    neural_compressor/common/base_config.py     +13 −3
    neural_compressor/torch/utils/utility.py    +2 −2
    test/3x/common/test_common.py               +49 −0
    test/3x/torch/test_config.py                +22 −7

neural_compressor/common/base_config.py  (+13 −3)
@@ -198,7 +198,7 @@ def local_config(self):
     def local_config(self, config):
         self._local_config = config

-    def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
+    def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig:
         if operator_name in self.local_config:
             logger.warning("The configuration for %s has already been set, update it.", operator_name)
         self.local_config[operator_name] = config
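
set_local keeps last-write-wins semantics: a key that is already present is overwritten after a warning. A standalone sketch of that behavior, with a plain dict standing in for local_config:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

local_config = {}

def set_local(operator_name, config):
    # Mirrors the warn-then-overwrite logic of BaseConfig.set_local above.
    if operator_name in local_config:
        logger.warning("The configuration for %s has already been set, update it.", operator_name)
    local_config[operator_name] = config

set_local("fc1", {"bits": 6})
set_local("fc1", {"bits": 8})  # warns, then overwrites
print(local_config)            # {'fc1': {'bits': 8}}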
@@ -392,14 +392,16 @@ def _get_op_name_op_type_config(self):
         op_name_config_dict = dict()
         for name, config in self.local_config.items():
             if self._is_op_type(name):
-                op_type_config_dict[name] = config
+                # Convert the Callable to String.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
             else:
                 op_name_config_dict[name] = config
         return op_type_config_dict, op_name_config_dict

     def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
-    ) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]:
+    ) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
         config_mapping = OrderedDict()
         if config_list is None:
             config_list = [self]
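
The visible effect is in the keys of the mapping that to_config_mapping returns: operator types used to appear as class objects and now appear as class-name strings. (The new return annotation spells Union[str, str], which is equivalent to plain str.) An illustrative shape, with stand-in dicts instead of real config objects:

from collections import OrderedDict

config_mapping = OrderedDict()
config_mapping[("fc1", "Linear")] = {"bits": 6}  # key was ("fc1", torch.nn.Linear) before
config_mapping[("fc2", "Linear")] = {"bits": 4}
assert ("fc1", "Linear") in config_mapping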
@@ -416,6 +418,14 @@ def to_config_mapping(
                         config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
         return config_mapping

+    @staticmethod
+    def _op_type_to_str(op_type: Callable) -> str:
+        # * Ort and TF may override this method.
+        op_type_name = getattr(op_type, "__name__", "")
+        if op_type_name == "":
+            logger.warning("The op_type %s has no attribute __name__.", op_type)
+        return op_type_name
+
     @staticmethod
     def _is_op_type(name: str) -> bool:
         # * Ort and TF may override this method.
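
The getattr fallback matters because not every callable carries __name__: classes and plain functions do, but a functools.partial object, for example, does not, in which case _op_type_to_str logs the warning and returns an empty string. A quick standalone check:

import functools

def op_type_to_str(op_type):
    # Same lookup as BaseConfig._op_type_to_str above.
    return getattr(op_type, "__name__", "")

class Linear:
    pass

print(op_type_to_str(Linear))                     # Linear
print(op_type_to_str(functools.partial(Linear)))  # empty string -> warning path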

neural_compressor/torch/utils/utility.py  (+2 −2)
@@ -101,13 +101,13 @@ def set_module(model, op_name, new_module):
     setattr(second_last_module, name_list[-1], new_module)


-def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, Callable]]:
+def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
     module_dict = dict(model.named_modules())
     filter_result = []
     filter_result_set = set()
     for op_name, module in module_dict.items():
         if isinstance(module, tuple(white_module_list)):
-            pair = (op_name, type(module))
+            pair = (op_name, type(module).__name__)
             if pair not in filter_result_set:
                 filter_result_set.add(pair)
                 filter_result.append(pair)
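
After this change, model_info pairs each module name with its class-name string rather than the class itself. A self-contained demo of the patched function, copied from the hunk above plus the trailing return that the diff context omits:

from typing import Callable, List, Tuple

import torch

def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
    module_dict = dict(model.named_modules())
    filter_result = []
    filter_result_set = set()
    for op_name, module in module_dict.items():
        if isinstance(module, tuple(white_module_list)):
            pair = (op_name, type(module).__name__)
            if pair not in filter_result_set:
                filter_result_set.add(pair)
                filter_result.append(pair)
    return filter_result

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
print(get_model_info(model, white_module_list=[torch.nn.Linear]))
# [('0', 'Linear'), ('2', 'Linear')]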

test/3x/common/test_common.py  (+49 −0)
@@ -75,6 +75,29 @@ def __repr__(self) -> str:
         return "FakeModel"


+class FakeOpType:
+    def __init__(self) -> None:
+        self.name = "fake_module"
+
+    def __call__(self, x) -> Any:
+        return x
+
+    def __repr__(self) -> str:
+        return "FakeModule"
+
+
+class OP_TYPE1(FakeOpType):
+    pass
+
+
+class OP_TYPE2(FakeOpType):
+    pass
+
+
+def build_simple_fake_model():
+    return FakeModel()
+
+
 @register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME, priority=PRIORITY_FAKE_ALGO)
 class FakeAlgoConfig(BaseConfig):
     """Config class for fake algo."""
@@ -257,6 +280,32 @@ def test_mixed_two_algos(self):
         self.assertIn(OP1_NAME, [op_info[0] for op_info in config_mapping])
         self.assertIn(OP2_NAME, [op_info[0] for op_info in config_mapping])

+    def test_set_local_op_name(self):
+        quant_config = FakeAlgoConfig(weight_bits=4)
+        # set `OP1_NAME`
+        fc1_config = FakeAlgoConfig(weight_bits=6)
+        quant_config.set_local("OP1_NAME", fc1_config)
+        model_info = FAKE_MODEL_INFO
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 4)
+        self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)
+
+    def test_set_local_op_type(self):
+        quant_config = FakeAlgoConfig(weight_bits=4)
+        # set all `OP_TYPE1`
+        fc1_config = FakeAlgoConfig(weight_bits=6)
+        quant_config.set_local(OP_TYPE1, fc1_config)
+        model_info = FAKE_MODEL_INFO
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("OP1_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP2_NAME", "OP_TYPE1")].weight_bits == 6)
+        self.assertTrue(configs_mapping[("OP3_NAME", "OP_TYPE2")].weight_bits == 4)
+

 class TestConfigSet(unittest.TestCase):
     def setUp(self):
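
Both tests read FAKE_MODEL_INFO, which is defined elsewhere in test_common.py and not shown in this diff. Judging from the keys asserted above, it is presumably a list of (op_name, op_type_name) pairs along these lines (a hypothetical reconstruction, not the actual fixture):

# Hypothetical reconstruction; the real fixture lives outside this diff.
FAKE_MODEL_INFO = [
    ("OP1_NAME", "OP_TYPE1"),
    ("OP2_NAME", "OP_TYPE1"),
    ("OP3_NAME", "OP_TYPE2"),
]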

test/3x/torch/test_config.py  (+22 −7)
@@ -147,8 +147,8 @@ def test_config_white_lst2(self):
         logger.info(quant_config)
         configs_mapping = quant_config.to_config_mapping(model_info=model_info)
         logger.info(configs_mapping)
-        self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6)
-        self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4)

     def test_config_from_dict(self):
         quant_config = {
@@ -253,16 +253,31 @@ def test_config_mapping(self):
         logger.info(quant_config)
         configs_mapping = quant_config.to_config_mapping(model_info=model_info)
         logger.info(configs_mapping)
-        self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 6)
-        self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 4)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 4)
         # test regular matching
         fc_config = RTNConfig(bits=5, dtype="int8")
         quant_config.set_local("fc", fc_config)
         configs_mapping = quant_config.to_config_mapping(model_info=model_info)
         logger.info(configs_mapping)
-        self.assertTrue(configs_mapping[("fc1", torch.nn.Linear)].bits == 5)
-        self.assertTrue(configs_mapping[("fc2", torch.nn.Linear)].bits == 5)
-        self.assertTrue(configs_mapping[("fc3", torch.nn.Linear)].bits == 5)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 5)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 5)
+        self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 5)
+
+    def test_set_local_op_type(self):
+        quant_config = RTNConfig(bits=4, dtype="nf4")
+        # set all `Linear`
+        fc1_config = RTNConfig(bits=6, dtype="int8")
+        quant_config.set_local(torch.nn.Linear, fc1_config)
+        # get model and quantize
+        fp32_model = build_simple_torch_model()
+        model_info = get_model_info(fp32_model, white_module_list=[torch.nn.Linear])
+        logger.info(quant_config)
+        configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+        logger.info(configs_mapping)
+        self.assertTrue(configs_mapping[("fc1", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc2", "Linear")].bits == 6)
+        self.assertTrue(configs_mapping[("fc3", "Linear")].bits == 6)

     def test_gptq_config(self):
         gptq_config1 = GPTQConfig(bits=8, act_order=True)
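
build_simple_torch_model is also defined outside this diff; the fc1/fc2/fc3 assertions imply a small module with three Linear layers, roughly like the sketch below (the attribute names matter, the layer sizes are made up):

import torch

class SimpleModel(torch.nn.Module):
    # Hypothetical stand-in; only the attribute names fc1/fc2/fc3 matter here.
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(32, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 8)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))

def build_simple_torch_model():
    return SimpleModel()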
