Skip to content

Commit 2351abb

Browse files
committed
update per review
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent ff99765 commit 2351abb

File tree

2 files changed

+9
-10
lines changed

2 files changed: +9 −10 lines changed

neural_compressor/torch/quantization/config.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, is_transformers_imported, logger
4848
from neural_compressor.torch.utils.constants import (
49+
LM_HEAD_NAMES,
4950
PRIORITY_AUTOROUND,
5051
PRIORITY_AWQ,
5152
PRIORITY_GPTQ,
@@ -198,8 +199,7 @@ def to_config_mapping(
198199
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
199200
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
200201
if not self.quant_lm_head:
201-
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
202-
self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32"))
202+
self.set_local(LM_HEAD_NAMES, RTNConfig(dtype="fp32"))
203203
config_mapping = super().to_config_mapping(config_list, model_info)
204204
return config_mapping
205205

@@ -359,8 +359,7 @@ def to_config_mapping(
359359
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
360360
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
361361
if not self.quant_lm_head:
362-
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
363-
self.set_local(usual_lm_head_names, GPTQConfig(dtype="fp32"))
362+
self.set_local(LM_HEAD_NAMES, GPTQConfig(dtype="fp32"))
364363
config_mapping = super().to_config_mapping(config_list, model_info)
365364
return config_mapping
366365

@@ -502,8 +501,7 @@ def to_config_mapping(
502501
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
503502
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
504503
if not self.quant_lm_head:
505-
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
506-
self.set_local(usual_lm_head_names, AWQConfig(dtype="fp32"))
504+
self.set_local(LM_HEAD_NAMES, AWQConfig(dtype="fp32"))
507505
config_mapping = super().to_config_mapping(config_list, model_info)
508506
return config_mapping
509507

@@ -641,8 +639,7 @@ def to_config_mapping(
641639
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
642640
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
643641
if not self.quant_lm_head:
644-
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
645-
self.set_local(usual_lm_head_names, TEQConfig(dtype="fp32"))
642+
self.set_local(LM_HEAD_NAMES, TEQConfig(dtype="fp32"))
646643
config_mapping = super().to_config_mapping(config_list, model_info)
647644
return config_mapping
648645

@@ -1269,8 +1266,7 @@ def to_config_mapping(
12691266
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
12701267
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
12711268
if not self.quant_lm_head:
1272-
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
1273-
self.set_local(usual_lm_head_names, HQQConfig(dtype="fp32"))
1269+
self.set_local(LM_HEAD_NAMES, HQQConfig(dtype="fp32"))
12741270
config_mapping = super().to_config_mapping(config_list, model_info)
12751271
return config_mapping
12761272

neural_compressor/torch/utils/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,6 @@
6262
class LoadFormat(Enum):
6363
DEFAULT = "default"
6464
HUGGINGFACE = "huggingface"
65+
66+
67+
LM_HEAD_NAMES = [".*lm_head", ".*output_layer", ".*embed_out"]

0 commit comments

Comments (0)