Skip to content

Commit e6a30a3

Browse files
Merge branch 'main' into ns/nf4_f8e4m3_proposal
2 parents fa63e40 + 3befef7 commit e6a30a3

File tree

1 file changed

+37
-33
lines changed

1 file changed

+37
-33
lines changed

optimum/commands/export/openvino.py

+37 -33
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
2222

2323
from ...exporters import TasksManager
24-
from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
24+
from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available, is_nncf_available
2525
from ...intel.utils.modeling_utils import _infer_library_from_model_name_or_path
2626
from ...utils.save_utils import maybe_load_preprocessors
2727
from ..base import BaseOptimumCLICommand, CommandInfo
@@ -343,40 +343,44 @@ def run(self):
343343
)
344344
elif self.args.weight_format in {"fp16", "fp32"}:
345345
ov_config = OVConfig(dtype=self.args.weight_format)
346-
elif self.args.weight_format is not None:
347-
# For int4 quantization if no parameter is provided, then use the default config if exists
348-
if no_compression_parameter_provided(self.args) and self.args.weight_format == "int4":
349-
quantization_config = get_default_int4_config(self.args.model)
350-
else:
351-
quantization_config = prepare_wc_config(self.args, _DEFAULT_4BIT_CONFIG)
352-
353-
if quantization_config.get("dataset", None) is not None:
354-
quantization_config["trust_remote_code"] = self.args.trust_remote_code
355-
ov_config = OVConfig(quantization_config=quantization_config)
356346
else:
357-
if self.args.dataset is None:
358-
raise ValueError(
359-
"Dataset is required for full quantization. Please provide it with --dataset argument."
360-
)
361-
362-
if self.args.quant_mode in ["nf4_f8e4m3", "int4_f8e4m3"]:
363-
wc_config = prepare_wc_config(self.args, _DEFAULT_4BIT_CONFIG)
364-
weight_dtype_map = {"nf4_f8e4m3": "nf4", "int4_f8e4m3": "int4"}
365-
wc_config["dtype"] = weight_dtype_map[self.args.quant_mode]
366-
367-
q_config = prepare_q_config(self.args)
368-
q_config["dtype"] = "f8e4m3"
369-
370-
quantization_config = {
371-
"weight_quantization_config": wc_config,
372-
"full_quantization_config": q_config,
373-
"num_samples": self.args.num_samples,
374-
"dataset": self.args.dataset,
375-
"trust_remote_code": self.args.trust_remote_code,
376-
}
347+
if not is_nncf_available():
348+
raise ImportError("Applying quantization requires nncf, please install it with `pip install nncf`")
349+
350+
if self.args.weight_format is not None:
351+
# For int4 quantization if no parameter is provided, then use the default config if exists
352+
if no_compression_parameter_provided(self.args) and self.args.weight_format == "int4":
353+
quantization_config = get_default_int4_config(self.args.model)
354+
else:
355+
quantization_config = prepare_wc_config(self.args, _DEFAULT_4BIT_CONFIG)
356+
357+
if quantization_config.get("dataset", None) is not None:
358+
quantization_config["trust_remote_code"] = self.args.trust_remote_code
359+
ov_config = OVConfig(quantization_config=quantization_config)
377360
else:
378-
quantization_config = prepare_q_config(self.args)
379-
ov_config = OVConfig(quantization_config=quantization_config)
361+
if self.args.dataset is None:
362+
raise ValueError(
363+
"Dataset is required for full quantization. Please provide it with --dataset argument."
364+
)
365+
366+
if self.args.quant_mode in ["nf4_f8e4m3", "int4_f8e4m3"]:
367+
wc_config = prepare_wc_config(self.args, _DEFAULT_4BIT_CONFIG)
368+
weight_dtype_map = {"nf4_f8e4m3": "nf4", "int4_f8e4m3": "int4"}
369+
wc_config["dtype"] = weight_dtype_map[self.args.quant_mode]
370+
371+
q_config = prepare_q_config(self.args)
372+
q_config["dtype"] = "f8e4m3"
373+
374+
quantization_config = {
375+
"weight_quantization_config": wc_config,
376+
"full_quantization_config": q_config,
377+
"num_samples": self.args.num_samples,
378+
"dataset": self.args.dataset,
379+
"trust_remote_code": self.args.trust_remote_code,
380+
}
381+
else:
382+
quantization_config = prepare_q_config(self.args)
383+
ov_config = OVConfig(quantization_config=quantization_config)
380384

381385
quantization_config = ov_config.quantization_config if ov_config else None
382386
quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None

0 commit comments

Comments
 (0)