
Commit 8656c26

Merge branch 'huggingface:main' into qwen
2 parents e75b45b + fe10aaa commit 8656c26

19 files changed: +753 −230 lines changed

docs/source/openvino/export.mdx (+4 −5)

@@ -31,7 +31,7 @@ Check out the help for more options:
 
 ```text
 usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--framework {pt,tf}] [--trust-remote-code]
-                                   [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8}]
+                                   [--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8,f8e4m3,f8e5m2}]
                                    [--library {transformers,diffusers,timm,sentence_transformers,open_clip}]
                                    [--cache_dir CACHE_DIR] [--pad-token-id PAD_TOKEN_ID] [--ratio RATIO] [--sym]
                                    [--group-size GROUP_SIZE] [--backup-precision {none,int8_sym,int8_asym}]
@@ -67,10 +67,9 @@ Optional arguments:
                         on your local machine arbitrary code present in the model repository.
   --weight-format {fp32,fp16,int8,int4,mxfp4,nf4}
                         The weight format of the exported model.
-  --quant-mode {int8}
+  --quant-mode {int8,f8e4m3,f8e5m2}
                         Quantization precision mode. This is used for applying full model quantization including
-                        activations. The only currently supported choice is 'int8' for int8 quantization of both
-                        weights and activations.
+                        activations.
   --library {transformers,diffusers,timm,sentence_transformers,open_clip}
                         The library used to load the model before export. If not provided, will attempt to infer the
                         local checkpoint's library
@@ -166,7 +165,7 @@ Models larger than 1 billion parameters are exported to the OpenVINO format with
 </Tip>
 
 
-Besides weight-only quantization, you can also apply full model quantization including activations by setting `--quant-mode` to `int8`. This will quantize both weights and activations of Linear, Convolutional and some other layers to int8. Currently this is only supported for speech-to-text models. Please see example below.
+Besides weight-only quantization, you can also apply full model quantization including activations by setting `--quant-mode` to the preferred precision. This will quantize both weights and activations of Linear, Convolutional and some other layers to the selected mode. Please see the example below.
 
 ```bash
 optimum-cli export openvino -m openai/whisper-large-v3-turbo --quant-mode int8 --dataset librispeech --num-samples 32 --smooth-quant-alpha 0.9 ./whisper-large-v3-turbo
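The same full quantization can also be driven from Python; a minimal sketch, assuming `OVQuantizationConfig` exposes `dataset`, `num_samples` and `smooth_quant_alpha` arguments mirroring the CLI flags above:

```python
# Minimal sketch of a Python-API equivalent of the CLI command above; assumes
# OVQuantizationConfig mirrors --dataset/--num-samples/--smooth-quant-alpha.
from optimum.intel import OVModelForSpeechSeq2Seq, OVQuantizationConfig

q_config = OVQuantizationConfig(
    dataset="librispeech",      # calibration dataset for activation statistics
    num_samples=32,             # number of calibration samples
    smooth_quant_alpha=0.9,     # SmoothQuant migration strength
)
model = OVModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3-turbo", export=True, quantization_config=q_config
)
model.save_pretrained("whisper-large-v3-turbo")
```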

optimum/commands/export/openvino.py (+1 −5)

@@ -78,11 +78,10 @@ def parse_args_openvino(parser: "ArgumentParser"):
     optional_group.add_argument(
         "--quant-mode",
         type=str,
-        choices=["int8"],
+        choices=["int8", "f8e4m3", "f8e5m2"],
         default=None,
         help=(
             "Quantization precision mode. This is used for applying full model quantization including activations. "
-            "The only currently supported choice is 'int8' for int8 quantization of both weights and activations."
         ),
     )
     optional_group.add_argument(
@@ -365,9 +364,6 @@ def run(self):
             quantization_config["trust_remote_code"] = self.args.trust_remote_code
             ov_config = OVConfig(quantization_config=quantization_config)
         else:
-            if self.args.quant_mode != "int8":
-                raise ValueError("Only 'int8' quantization mode is currently supported.")
-
             quantization_config = {
                 "weight_format": self.args.quant_mode,
                 "activation_format": self.args.quant_mode,

optimum/exporters/openvino/__main__.py (+6 −1)

@@ -86,6 +86,7 @@ def infer_task(
             revision=revision,
             cache_dir=cache_dir,
             token=token,
+            library_name=library_name,
         )
     except KeyError as e:
         raise KeyError(
@@ -274,7 +275,11 @@ def main_export(
             f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
         )
 
-    if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
+    if (
+        is_transformers_version(">=", "4.36")
+        and is_transformers_version("<=", "4.45.0")
+        and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
+    ):
         loading_kwargs["attn_implementation"] = "eager"
 
     # some models force flash_attn attention by default that does not support load model on cpu
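The reworked condition bounds the eager-attention workaround to a transformers version window instead of every release from 4.36 onward. A standalone sketch of the same bounded gate built on `packaging` (the `is_transformers_version` helper is the repository's own; this version is only illustrative):

```python
# Illustrative stand-in for the bounded version gate above, using packaging
# instead of the repository's is_transformers_version helper.
from packaging import version
import transformers

tf_version = version.parse(transformers.__version__)
loading_kwargs = {}

# Only the affected release window needs eager attention for these archs.
if version.parse("4.36") <= tf_version <= version.parse("4.45.0"):
    loading_kwargs["attn_implementation"] = "eager"
```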

optimum/exporters/openvino/convert.py (+88 −46)

@@ -28,10 +28,13 @@
 from openvino.tools.ovc import convert_model
 from optimum.exporters import TasksManager
 from optimum.exporters.utils import (
-    _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
+    DECODER_NAME,
+    ENCODER_NAME,
+    _get_submodels_for_export_encoder_decoder,
+    get_diffusion_models_for_export,
 )
 from optimum.exporters.utils import (
-    get_diffusion_models_for_export,
+    _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
 )
 from optimum.intel.utils.import_utils import (
     _diffusers_version,
@@ -43,7 +46,6 @@
     _torch_version,
     _transformers_version,
     compare_versions,
-    is_diffusers_version,
     is_openvino_tokenizers_version,
     is_openvino_version,
     is_tokenizers_version,
@@ -101,15 +103,18 @@ def _set_runtime_options(
 ):
     for model_name in models_and_export_configs.keys():
         _, sub_export_config = models_and_export_configs[model_name]
-        sub_export_config.runtime_options = {}
+        if not hasattr(sub_export_config, "runtime_options"):
+            sub_export_config.runtime_options = {}
         if (
-            "diffusers" in library_name
-            or "text-generation" in task
+            "text-generation" in task
             or ("image-text-to-text" in task and model_name == "language_model")
+            or getattr(sub_export_config, "stateful", False)
         ):
             sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
         if not quantized_model and (
-            "text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
+            "text-generation" in task
+            or ("image-text-to-text" in task and model_name == "language_model")
+            or getattr(sub_export_config, "stateful", False)
         ):
             sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16"
 
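Two behavioral changes hide in this hunk: `runtime_options` set earlier (for example by the diffusion helpers further down) now survive instead of being reset, and stateful submodel configs receive the precision hints too. A toy sketch of the guard, with a hypothetical config object:

```python
# Toy sketch of the new guard with a hypothetical config object: a
# runtime_options dict set upstream is no longer wiped, and stateful
# configs now receive the KV-cache precision hint as well.
class DummyExportConfig:
    stateful = True
    runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}  # set upstream

cfg = DummyExportConfig()
if not hasattr(cfg, "runtime_options"):  # previously: unconditional reset to {}
    cfg.runtime_options = {}
if getattr(cfg, "stateful", False):
    cfg.runtime_options["KV_CACHE_PRECISION"] = "f16"
print(cfg.runtime_options)
# {'ACTIVATIONS_SCALE_FACTOR': '128.0', 'KV_CACHE_PRECISION': 'f16'}
```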

@@ -642,10 +647,14 @@ def export_from_model(
 
     logger.info(f"Automatic task detection to: {task}.")
 
+    is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
+    model_type = getattr(getattr(model, "config", {}), "model_type", "")
     stateful = stateful and (
-        ensure_export_task_support_stateful(task)
-        or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", ""))
+        ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type)
     )
+
+    if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
+        stateful = False
     # TODO: support onnx_config.py in the model repo
     if custom_architecture and custom_export_configs is None:
         raise ValueError(
@@ -687,6 +696,11 @@ def export_from_model(
     if library_name == "diffusers":
         export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
         stateful_submodels = False
+    elif stateful and is_encoder_decoder and not custom_architecture:
+        export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
+            model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
+        )
+        stateful_submodels = [False, True]
     else:
         logging.disable(logging.INFO)
         export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(
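Together, these two hunks route encoder-decoder models onto a dedicated stateful path, falling back to stateless export when the architecture predates transformers' `Cache` class. A condensed, hypothetical sketch of the decision flow (only the attribute and helper names come from the diff):

```python
# Condensed, hypothetical sketch of the dispatch added to export_from_model.
def choose_export_path(model, task, stateful, custom_architecture, library_name):
    config = getattr(model, "config", None)
    is_encoder_decoder = getattr(config, "is_encoder_decoder", False)

    # Models without Cache support cannot hold the KV-cache as model state.
    if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
        stateful = False

    if library_name == "diffusers":
        return "diffusion"
    if stateful and is_encoder_decoder and not custom_architecture:
        return "stateful-seq2seq"  # stateless encoder + stateful decoder
    return "default"
```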
@@ -999,45 +1013,29 @@ def _get_submodels_and_export_configs(
 def get_diffusion_models_for_export_ext(
     pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
 ):
-    if is_diffusers_version(">=", "0.29.0"):
-        from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
-
-        sd3_pipes = [StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline]
-        if is_diffusers_version(">=", "0.30.0"):
-            from diffusers import StableDiffusion3InpaintPipeline
-
-            sd3_pipes.append(StableDiffusion3InpaintPipeline)
-
-        is_sd3 = isinstance(pipeline, tuple(sd3_pipes))
-    else:
-        is_sd3 = False
-
-    if is_diffusers_version(">=", "0.30.0"):
-        from diffusers import FluxPipeline
-
-        flux_pipes = [FluxPipeline]
-
-        if is_diffusers_version(">=", "0.31.0"):
-            from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline
-
-            flux_pipes.extend([FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline])
-
-        if is_diffusers_version(">=", "0.32.0"):
-            from diffusers import FluxFillPipeline
-
-            flux_pipes.append(FluxFillPipeline)
-
-        is_flux = isinstance(pipeline, tuple(flux_pipes))
-    else:
-        is_flux = False
-
-    if not is_sd3 and not is_flux:
-        return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
-    if is_sd3:
+    is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
+    is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
+    is_flux = pipeline.__class__.__name__.startswith("Flux")
+    is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
+    is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")
+
+    if is_sd or is_sdxl or is_lcm:
+        models_for_export = get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
+        if is_sdxl and pipeline.vae.config.force_upcast:
+            models_for_export["vae_encoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}
+            models_for_export["vae_decoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "128.0"}
+
+        # only SD 2.1 has overflow issue, it uses different prediction_type than other models
+        if is_sd and pipeline.scheduler.config.prediction_type == "v_prediction":
+            models_for_export["vae_encoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
+            models_for_export["vae_decoder"][1].runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
+
+    elif is_sd3:
         models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
-    else:
+    elif is_flux:
         models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
-
+    else:
+        raise ValueError(f"Unsupported pipeline type `{pipeline.__class__.__name__}` provided")
     return None, models_for_export
 
 
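The rewrite trades version-gated `isinstance` checks for class-name prefix matching, so pipelines added in newer diffusers releases are classified without importing them. A toy illustration (the classes here are local stand-ins, not diffusers imports):

```python
# Toy illustration of prefix-based pipeline classification; the classes are
# local stand-ins, not diffusers imports.
class StableDiffusionXLPipeline: ...
class StableDiffusion3InpaintPipeline: ...
class FluxFillPipeline: ...

for pipe_cls in (StableDiffusionXLPipeline, StableDiffusion3InpaintPipeline, FluxFillPipeline):
    name = pipe_cls.__name__
    is_sdxl = name.startswith("StableDiffusionXL")
    is_sd3 = name.startswith("StableDiffusion3")
    is_flux = name.startswith("Flux")
    print(name, "->", "sdxl" if is_sdxl else "sd3" if is_sd3 else "flux" if is_flux else "other")
```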

@@ -1135,6 +1133,7 @@ def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
             int_dtype=int_dtype,
             float_dtype=float_dtype,
         )
+        export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
         models_for_export["text_encoder_3"] = (text_encoder_3, export_config)
 
     return models_for_export
@@ -1172,6 +1171,7 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
     transformer_export_config = export_config_constructor(
         pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
     )
+    transformer_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
     models_for_export["transformer"] = (transformer, transformer_export_config)
 
     # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
@@ -1187,6 +1187,7 @@
     vae_encoder_export_config = vae_config_constructor(
         vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
     )
+    vae_encoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
     models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
 
     # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
@@ -1202,6 +1203,7 @@
     vae_decoder_export_config = vae_config_constructor(
         vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
     )
+    vae_decoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
     models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
 
     text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
@@ -1218,6 +1220,46 @@
             int_dtype=int_dtype,
             float_dtype=float_dtype,
         )
+        export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
         models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
 
     return models_for_export
+
+
+def _get_encoder_decoder_stateful_models_for_export(
+    model: Union["PreTrainedModel", "TFPreTrainedModel"],
+    task: str,
+    _variant: str,
+    library_name: str,
+    int_dtype: str = "int64",
+    float_dtype: str = "fp32",
+    preprocessors: Optional[List[Any]] = None,
+):
+    export_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=model, exporter="openvino", task=task, library_name=library_name
+    )
+    export_config = export_config_constructor(
+        model.config,
+        int_dtype=int_dtype,
+        float_dtype=float_dtype,
+        preprocessors=preprocessors,
+        legacy=False,
+    )
+
+    export_config.variant = _variant
+    all_variants = "\n".join([f"    - {name}: {description}" for name, description in export_config.VARIANTS.items()])
+    logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")
+
+    models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False)
+
+    encoder_export_config = export_config.with_behavior("encoder")
+    models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)
+
+    decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
+
+    decoder_export_config_with_past.stateful = True
+    models_for_export[DECODER_NAME] = (
+        models_for_export[DECODER_NAME],
+        decoder_export_config_with_past,
+    )
+    return None, models_for_export
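The helper mirrors `stateful_submodels = [False, True]` at its call site: the first entry (encoder) stays stateless, while the second (decoder) is exported with past inputs folded into model state. A schematic of the returned mapping using placeholder objects; the key strings assume `ENCODER_NAME`/`DECODER_NAME` from `optimum.exporters.utils` resolve to `"encoder_model"`/`"decoder_model"`:

```python
# Schematic of the structure returned above, with placeholder objects.
encoder_submodel, decoder_submodel = object(), object()

class PlaceholderConfig:  # stand-in for the real OpenVINO export config
    def __init__(self, behavior, use_past=False, stateful=False):
        self.behavior, self.use_past, self.stateful = behavior, use_past, stateful

models_for_export = {
    "encoder_model": (encoder_submodel, PlaceholderConfig("encoder")),
    "decoder_model": (decoder_submodel, PlaceholderConfig("decoder", use_past=True, stateful=True)),
}
```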

optimum/exporters/openvino/model_configs.py (+70)

@@ -20,6 +20,7 @@
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
 from transformers.utils import is_tf_available
 
+from optimum.exporters.onnx.base import ConfigBehavior
 from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.onnx.model_configs import (
     CLIPOnnxConfig,
@@ -38,8 +39,10 @@
     MistralOnnxConfig,
     MPTOnnxConfig,
     PhiOnnxConfig,
+    T5OnnxConfig,
     UNetOnnxConfig,
     VisionOnnxConfig,
+    WhisperOnnxConfig,
 )
 from optimum.exporters.onnx.model_patcher import ModelPatcher
 from optimum.exporters.tasks import TasksManager
@@ -102,6 +105,7 @@
     Qwen2VLVisionEmbMergerPatcher,
     QwenModelPatcher,
     RotaryEmbPatcher,
+    StatefulSeq2SeqDecoderPatcher,
     UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
 )
@@ -2611,3 +2615,69 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "whisper",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "audio-classification",
+        "automatic-speech-recognition",
+        "automatic-speech-recognition-with-past",
+    ],
+    library_name="transformers",
+)
+class WhisperOpenVINOConfig(WhisperOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+        return common_inputs
+
+
+@register_in_tasks_manager(
+    "t5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class T5OpenVINOConfig(T5OnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+        return common_inputs
+
+
+@register_in_tasks_manager(
+    "mt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class MT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "longt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class LongT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
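With these configs registered, seq2seq exports that go through the OpenVINO tasks manager pick up the stateful decoder automatically. A quick smoke test, sketched with the standard optimum-intel loading path:

```python
# Sketch: exporting a T5 checkpoint via optimum-intel; with the configs above
# registered, the decoder should be exported in its stateful form.
from optimum.intel import OVModelForSeq2SeqLM
from transformers import AutoTokenizer

model_id = "google-t5/t5-small"
model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs)[0], skip_special_tokens=True))
```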
