26
26
from neural_compressor .torch .utils import (
27
27
HPU_SAFE_WEIGHTS_NAME ,
28
28
HPU_WEIGHT_NAME ,
29
+ LM_HEAD_NAMES ,
29
30
QCONFIG_NAME ,
30
31
WEIGHT_NAME ,
31
32
SaveLoadFormat ,
33
+ get_accelerator ,
34
+ get_enum_from_format ,
32
35
logger ,
33
36
set_module ,
34
- get_enum_from_format ,
35
- LM_HEAD_NAMES ,
36
- get_accelerator ,
37
37
)
38
38
39
39
from .modules import HPUWeightOnlyLinear , INCWeightOnlyLinear , MulLinear
@@ -964,8 +964,9 @@ def change_config_to_hf_format(config_mappings):
964
964
"true_sequential" : True ,
965
965
"model_name_or_path" : None ,
966
966
"model_file_base_name" : "model" ,
967
- "quant_method" : "gptq" # INC is using AutoGPTQ format for RTN, GPTQ, AWQ, and TEQ
967
+ "quant_method" : "gptq" , # INC is using AutoGPTQ format for RTN, GPTQ, AWQ, and TEQ
968
968
}
969
+
969
970
def _is_lm_head (name ):
970
971
for lm_head_name in LM_HEAD_NAMES :
971
972
if re .match (lm_head_name , name ):
@@ -992,17 +993,21 @@ def _is_lm_head(name):
992
993
else :
993
994
assert bits == config .bits , "bits should be the same for all modules, got {bits} and {config.bits}."
994
995
assert sym == config .use_sym , "sym should be the same for all modules, got {sym} and {config.use_sym}."
995
- assert group_size == config .group_size , \
996
- "group_size should be the same for all modules, got {group_size} and {config.group_size}."
996
+ assert (
997
+ group_size == config .group_size
998
+ ), "group_size should be the same for all modules, got {group_size} and {config.group_size}."
997
999
if hasattr (config , "percdamp" ):
998
- assert damp_percent == config .percdamp , \
999
- "percdamp should be the same for all modules, got {damp_percent} and {config.percdamp}."
1000
+ assert (
1001
+ damp_percent == config .percdamp
1002
+ ), "percdamp should be the same for all modules, got {damp_percent} and {config.percdamp}."
1000
1003
if hasattr (config , "act_order" ):
1001
- assert desc_act == config .act_order , \
1002
- "act_order should be the same for all modules, got {desc_act} and {config.act_order}."
1004
+ assert (
1005
+ desc_act == config .act_order
1006
+ ), "act_order should be the same for all modules, got {desc_act} and {config.act_order}."
1003
1007
if hasattr (config , "true_sequential" ):
1004
- assert true_sequential == config .true_sequential , \
1005
- "true_sequential should be the same for all modules, got {true_sequential} and {config.true_sequential}."
1008
+ assert (
1009
+ true_sequential == config .true_sequential
1010
+ ), "true_sequential should be the same for all modules, got {true_sequential} and {config.true_sequential}."
1006
1011
default_quantization_config ["bits" ] = bits
1007
1012
default_quantization_config ["group_size" ] = group_size
1008
1013
default_quantization_config ["damp_percent" ] = damp_percent
0 commit comments