 )
 from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, is_transformers_imported, logger
 from neural_compressor.torch.utils.constants import (
+    LM_HEAD_NAMES,
     PRIORITY_AUTOROUND,
     PRIORITY_AWQ,
     PRIORITY_GPTQ,
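For reference, the new LM_HEAD_NAMES constant presumably centralizes the regex list that each to_config_mapping below previously built inline. A minimal sketch of the corresponding addition to neural_compressor/torch/utils/constants.py, assuming the definition mirrors the removed lines:

# Sketch only; the exact definition is assumed from the list removed in the hunks below.
# Regexes matching common LM-head module names across model architectures.
LM_HEAD_NAMES = [".*lm_head", ".*output_layer", ".*embed_out"]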
@@ -198,8 +199,7 @@ def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
     ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
         if not self.quant_lm_head:
-            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
-            self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32"))
+            self.set_local(LM_HEAD_NAMES, RTNConfig(dtype="fp32"))
         config_mapping = super().to_config_mapping(config_list, model_info)
         return config_mapping
 
@@ -359,8 +359,7 @@ def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
     ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
         if not self.quant_lm_head:
-            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
-            self.set_local(usual_lm_head_names, GPTQConfig(dtype="fp32"))
+            self.set_local(LM_HEAD_NAMES, GPTQConfig(dtype="fp32"))
         config_mapping = super().to_config_mapping(config_list, model_info)
         return config_mapping
 
@@ -502,8 +501,7 @@ def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
     ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
         if not self.quant_lm_head:
-            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
-            self.set_local(usual_lm_head_names, AWQConfig(dtype="fp32"))
+            self.set_local(LM_HEAD_NAMES, AWQConfig(dtype="fp32"))
         config_mapping = super().to_config_mapping(config_list, model_info)
         return config_mapping
 
@@ -641,8 +639,7 @@ def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
     ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
         if not self.quant_lm_head:
-            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
-            self.set_local(usual_lm_head_names, TEQConfig(dtype="fp32"))
+            self.set_local(LM_HEAD_NAMES, TEQConfig(dtype="fp32"))
         config_mapping = super().to_config_mapping(config_list, model_info)
         return config_mapping
 
@@ -1269,8 +1266,7 @@ def to_config_mapping(
         self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
     ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
         if not self.quant_lm_head:
-            usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
-            self.set_local(usual_lm_head_names, HQQConfig(dtype="fp32"))
+            self.set_local(LM_HEAD_NAMES, HQQConfig(dtype="fp32"))
         config_mapping = super().to_config_mapping(config_list, model_info)
         return config_mapping
 
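Taken together, the hunks replace five duplicated literals with one shared constant; behavior is unchanged. A hedged usage sketch of the quant_lm_head flag these branches implement, assuming RTNConfig is exposed via neural_compressor.torch.quantization:

# Hypothetical usage: when quant_lm_head is False, modules whose names match
# LM_HEAD_NAMES (e.g. "lm_head") receive a local dtype="fp32" override in
# to_config_mapping, so weight-only quantization skips them.
from neural_compressor.torch.quantization import RTNConfig

config = RTNConfig(bits=4, quant_lm_head=False)
# Pass quant_lm_head=True instead to quantize the LM head with the rest.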