Commit dc32541

xin3he authored and xinhe3 committed
[SW-20808] Make sure save&load format is an Enum object (#58)
* [SW-20808] Make sure save&load format is an Enum object

Signed-off-by: Xin He <xinhe3@habana.ai>

* Update save_load_entry.py

---------

Signed-off-by: Xin He <xinhe3@habana.ai>
Co-authored-by: Xin He <xinhe3@habana.ai>
Signed-off-by: Xin He <xinhe3@habana.ai>
1 parent 6d4a097 commit dc32541

File tree

4 files changed: +26 -5 lines changed
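For context, the pitfall this commit closes: Python Enum members never compare equal to their string values, so one call site passing format="huggingface" while another passes SaveLoadFormat.HUGGINGFACE will silently fail == checks. A minimal sketch, with the enum members assumed from the diffs below (the real enum lives in neural_compressor.torch.utils):

from enum import Enum

class SaveLoadFormat(Enum):  # assumed shape, inferred from this diff
    DEFAULT = "default"
    HUGGINGFACE = "huggingface"
    VLLM = "vllm"

print(SaveLoadFormat.HUGGINGFACE == "huggingface")        # False: an Enum member != its plain string value
print(SaveLoadFormat.HUGGINGFACE.value == "huggingface")  # True, but every call site must remember .value

Normalizing to the enum once, at the top of each entry point, removes that class of bug.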

neural_compressor/torch/algorithms/fp8_quant/save_load.py (+4)

@@ -348,6 +348,10 @@ def save_for_multi_devices(model, checkpoint_dir="saved_results", format="huggin
         checkpoint_dir (str, optional): path to checkpoint. Defaults to "saved_results".
         format (str, optional): defaults to 'huggingface'.
     """
+    format = get_enum_from_format(format)
+    assert format == SaveLoadFormat.HUGGINGFACE, (
+        "Currently, only huggingface models are supported. Please set format='huggingface'."
+    )
     from safetensors.torch import save_file as safe_save_file
     if format == SaveLoadFormat.VLLM:
         import transformers

neural_compressor/torch/algorithms/weight_only/save_load.py (+2)

@@ -61,6 +61,7 @@ def save(model, output_dir="./saved_results", format=SaveLoadFormat.DEFAULT, **k
         - tokenizer (Tokenizer, optional): The tokenizer to be saved along with the model (only applicable for 'huggingface' format).
         - max_shard_size (str, optional): The maximum size for each shard (only applicable for 'huggingface' format). Defaults to "5GB".
     """
+    format = get_enum_from_format(format)
     os.makedirs(output_dir, exist_ok=True)
     cur_accelerator.synchronize()
     if format == SaveLoadFormat.HUGGINGFACE:  # pragma: no cover

@@ -128,6 +129,7 @@ def load(model_name_or_path, original_model=None, format=SaveLoadFormat.DEFAULT,
     Returns:
         torch.nn.Module: quantized model
     """
+    format = get_enum_from_format(format)
     model_loader = WOQModelLoader(model_name_or_path, original_model, format, device, **kwargs)
     model = model_loader.load_woq_model()
     return model

neural_compressor/torch/quantization/save_load_entry.py (+6 -4)

@@ -26,7 +26,7 @@
     RTNConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import SaveLoadFormat
+from neural_compressor.torch.utils import SaveLoadFormat, get_enum_from_format

 config_name_mapping = {
     FP8_QUANT: FP8Config,

@@ -45,6 +45,7 @@ def save(model, checkpoint_dir="saved_results", format="default"):
             quantized by llm-compressor(https://github.com/vllm-project/llm-compressor).
             Defaults to "default".
     """
+    format = get_enum_from_format(format)
     config_mapping = model.qconfig
     config_object = config_mapping[next(iter(config_mapping))]
     # fp8_quant

@@ -104,7 +105,8 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
     Returns:
         The quantized model
     """
-    if format == SaveLoadFormat.DEFAULT.value:
+    format = get_enum_from_format(format)
+    if format == SaveLoadFormat.DEFAULT:
         from neural_compressor.common.base_config import ConfigRegistry

         qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(model_name_or_path)), "qconfig.json")

@@ -133,7 +135,7 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
             model_name_or_path, original_model, format=SaveLoadFormat.DEFAULT, device=device
         )
         return qmodel.to(device)
-    elif format == SaveLoadFormat.HUGGINGFACE.value:
+    elif format == SaveLoadFormat.HUGGINGFACE:
         import transformers

         config = transformers.AutoConfig.from_pretrained(model_name_or_path, **kwargs)

@@ -156,4 +158,4 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
         qmodel = weight_only.load(model_name_or_path, format=SaveLoadFormat.HUGGINGFACE, device=device, **kwargs)
         return qmodel.to(device)
     else:
-        raise ValueError("`format` in load function can only be 'huggingface' or 'default', but get {}".format(format))
+        assert False, "This code path should never be reached."
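After normalization, load() dispatches on enum members instead of .value strings, and an unrecognized string never reaches the final else: get_enum_from_format (added in utility.py below) raises before the if/elif chain runs. A short sketch:

format = get_enum_from_format("default")       # -> SaveLoadFormat.DEFAULT, handled by the first branch
format = get_enum_from_format("HuggingFace")   # -> SaveLoadFormat.HUGGINGFACE, handled by the elif branch
format = get_enum_from_format("not-a-format")  # raises ValueError before any branch is evaluated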

neural_compressor/torch/utils/utility.py (+14 -1)

@@ -32,7 +32,7 @@
     detect_processor_type_based_on_hw,
     logger,
 )
-from neural_compressor.torch.utils import is_optimum_habana_available, is_transformers_imported
+from neural_compressor.torch.utils import is_optimum_habana_available, is_transformers_imported, SaveLoadFormat

 if is_transformers_imported():
     import transformers

@@ -711,3 +711,16 @@ def forward_wrapper(model, input):
     else:
         output = model(input)
     return output
+
+
+def get_enum_from_format(format):
+    """Make sure Save&Load format is an Enum object."""
+    if isinstance(format, SaveLoadFormat):
+        return format
+    for obj in SaveLoadFormat:
+        if format == obj.value:
+            return obj
+        elif format.upper() == obj.name:
+            return obj
+    raise ValueError(
+        f"Invalid format value ('{format}'). Enter one of [{[m.name for m in SaveLoadFormat]}]")
