26
26
from neural_compressor .torch .utils import (
27
27
HPU_SAFE_WEIGHTS_NAME ,
28
28
HPU_WEIGHT_NAME ,
29
+ LM_HEAD_NAMES ,
29
30
QCONFIG_NAME ,
30
31
WEIGHT_NAME ,
31
32
SaveLoadFormat ,
33
+ get_accelerator ,
34
+ get_enum_from_format ,
32
35
logger ,
33
36
set_module ,
34
- get_enum_from_format ,
35
- LM_HEAD_NAMES ,
36
- get_accelerator ,
37
37
)
38
38
39
39
from .modules import HPUWeightOnlyLinear , INCWeightOnlyLinear , MulLinear
@@ -966,8 +966,9 @@ def change_config_to_hf_format(config_mappings):
966
966
"true_sequential" : True ,
967
967
"model_name_or_path" : None ,
968
968
"model_file_base_name" : "model" ,
969
- "quant_method" : "gptq" # INC is using AutoGPTQ format for RTN, GPTQ, AWQ, and TEQ
969
+ "quant_method" : "gptq" , # INC is using AutoGPTQ format for RTN, GPTQ, AWQ, and TEQ
970
970
}
971
+
971
972
def _is_lm_head (name ):
972
973
for lm_head_name in LM_HEAD_NAMES :
973
974
if re .match (lm_head_name , name ):
@@ -994,17 +995,21 @@ def _is_lm_head(name):
994
995
else :
995
996
assert bits == config .bits , "bits should be the same for all modules, got {bits} and {config.bits}."
996
997
assert sym == config .use_sym , "sym should be the same for all modules, got {sym} and {config.use_sym}."
997
- assert group_size == config .group_size , \
998
- "group_size should be the same for all modules, got {group_size} and {config.group_size}."
998
+ assert (
999
+ group_size == config .group_size
1000
+ ), "group_size should be the same for all modules, got {group_size} and {config.group_size}."
999
1001
if hasattr (config , "percdamp" ):
1000
- assert damp_percent == config .percdamp , \
1001
- "percdamp should be the same for all modules, got {damp_percent} and {config.percdamp}."
1002
+ assert (
1003
+ damp_percent == config .percdamp
1004
+ ), "percdamp should be the same for all modules, got {damp_percent} and {config.percdamp}."
1002
1005
if hasattr (config , "act_order" ):
1003
- assert desc_act == config .act_order , \
1004
- "act_order should be the same for all modules, got {desc_act} and {config.act_order}."
1006
+ assert (
1007
+ desc_act == config .act_order
1008
+ ), "act_order should be the same for all modules, got {desc_act} and {config.act_order}."
1005
1009
if hasattr (config , "true_sequential" ):
1006
- assert true_sequential == config .true_sequential , \
1007
- "true_sequential should be the same for all modules, got {true_sequential} and {config.true_sequential}."
1010
+ assert (
1011
+ true_sequential == config .true_sequential
1012
+ ), "true_sequential should be the same for all modules, got {true_sequential} and {config.true_sequential}."
1008
1013
default_quantization_config ["bits" ] = bits
1009
1014
default_quantization_config ["group_size" ] = group_size
1010
1015
default_quantization_config ["damp_percent" ] = damp_percent
0 commit comments