Skip to content

Commit d5099a4

Browse files
authored
Merge branch 'main' into ea/precise_act_scale
2 parents 99586b1 + 248aabd commit d5099a4

11 files changed

+557
-93
lines changed

optimum/exporters/openvino/__main__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,11 @@ def main_export(
274274
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
275275
)
276276

277-
if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
277+
if (
278+
is_transformers_version(">=", "4.36")
279+
and is_transformers_version("<=", "4.45.0")
280+
and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
281+
):
278282
loading_kwargs["attn_implementation"] = "eager"
279283

280284
# some models force flash_attn attention by default that does not support load model on cpu

optimum/exporters/openvino/convert.py

+63-6
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@
2828
from openvino.tools.ovc import convert_model
2929
from optimum.exporters import TasksManager
3030
from optimum.exporters.utils import (
31-
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
31+
DECODER_NAME,
32+
ENCODER_NAME,
33+
_get_submodels_for_export_encoder_decoder,
34+
get_diffusion_models_for_export,
3235
)
3336
from optimum.exporters.utils import (
34-
get_diffusion_models_for_export,
37+
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
3538
)
3639
from optimum.intel.utils.import_utils import (
3740
_diffusers_version,
@@ -103,10 +106,16 @@ def _set_runtime_options(
103106
_, sub_export_config = models_and_export_configs[model_name]
104107
if not hasattr(sub_export_config, "runtime_options"):
105108
sub_export_config.runtime_options = {}
106-
if "text-generation" in task or ("image-text-to-text" in task and model_name == "language_model"):
109+
if (
110+
"text-generation" in task
111+
or ("image-text-to-text" in task and model_name == "language_model")
112+
or getattr(sub_export_config, "stateful", False)
113+
):
107114
sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
108115
if not quantized_model and (
109-
"text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
116+
"text-generation" in task
117+
or ("image-text-to-text" in task and model_name == "language_model")
118+
or getattr(sub_export_config, "stateful", False)
110119
):
111120
sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16"
112121

@@ -639,10 +648,14 @@ def export_from_model(
639648

640649
logger.info(f"Automatic task detection to: {task}.")
641650

651+
is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
652+
model_type = getattr(getattr(model, "config", {}), "model_type", "")
642653
stateful = stateful and (
643-
ensure_export_task_support_stateful(task)
644-
or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", ""))
654+
ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type)
645655
)
656+
657+
if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
658+
stateful = False
646659
# TODO: support onnx_config.py in the model repo
647660
if custom_architecture and custom_export_configs is None:
648661
raise ValueError(
@@ -684,6 +697,11 @@ def export_from_model(
684697
if library_name == "diffusers":
685698
export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
686699
stateful_submodels = False
700+
elif stateful and is_encoder_decoder and not custom_architecture:
701+
export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
702+
model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
703+
)
704+
stateful_submodels = [False, True]
687705
else:
688706
logging.disable(logging.INFO)
689707
export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(
@@ -1204,3 +1222,42 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
12041222
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
12051223

12061224
return models_for_export
1225+
1226+
1227+
def _get_encoder_decoder_stateful_models_for_export(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    task: str,
    _variant: str,
    library_name: str,
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    preprocessors: Optional[List[Any]] = None,
):
    """Build the submodels and export configs for a stateful encoder-decoder export.

    Returns a ``(main_export_config, models_for_export)`` pair where the main
    config is ``None`` (each submodel carries its own config) and
    ``models_for_export`` maps ``ENCODER_NAME`` / ``DECODER_NAME`` to
    ``(submodel, export_config)`` tuples.  The decoder config takes past key
    values as inputs and is flagged ``stateful`` so the KV cache is kept
    inside the exported model's state.

    Args:
        model: The seq2seq model to split into encoder/decoder submodels.
        task: Export task name used to look up the exporter config constructor.
        _variant: Export-config variant to select (e.g. ``"default"``).
        library_name: Source library of the model, used for the config lookup.
        int_dtype: Integer dtype forwarded to the export config.
        float_dtype: Float dtype forwarded to the export config.
        preprocessors: Optional preprocessors forwarded to the export config.
    """
    export_config_constructor = TasksManager.get_exporter_config_constructor(
        model=model, exporter="openvino", task=task, library_name=library_name
    )
    export_config = export_config_constructor(
        model.config,
        int_dtype=int_dtype,
        float_dtype=float_dtype,
        preprocessors=preprocessors,
        legacy=False,
    )

    export_config.variant = _variant
    all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
    logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")

    # use_past=False: the stateful decoder manages its own cache, so no separate
    # "decoder with past" submodel copy is extracted.
    models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False)

    encoder_export_config = export_config.with_behavior("encoder")
    models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)

    decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)

    # Mark the decoder config as stateful so downstream patchers/runtime options
    # treat the KV cache as internal model state rather than explicit I/O.
    decoder_export_config_with_past.stateful = True
    models_for_export[DECODER_NAME] = (
        models_for_export[DECODER_NAME],
        decoder_export_config_with_past,
    )
    # There is no single "main" export config for the pair — return None for it.
    return None, models_for_export

optimum/exporters/openvino/model_configs.py

+70
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
2121
from transformers.utils import is_tf_available
2222

23+
from optimum.exporters.onnx.base import ConfigBehavior
2324
from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
2425
from optimum.exporters.onnx.model_configs import (
2526
CLIPOnnxConfig,
@@ -38,8 +39,10 @@
3839
MistralOnnxConfig,
3940
MPTOnnxConfig,
4041
PhiOnnxConfig,
42+
T5OnnxConfig,
4143
UNetOnnxConfig,
4244
VisionOnnxConfig,
45+
WhisperOnnxConfig,
4346
)
4447
from optimum.exporters.onnx.model_patcher import ModelPatcher
4548
from optimum.exporters.tasks import TasksManager
@@ -102,6 +105,7 @@
102105
Qwen2VLVisionEmbMergerPatcher,
103106
QwenModelPatcher,
104107
RotaryEmbPatcher,
108+
StatefulSeq2SeqDecoderPatcher,
105109
UpdateCausalMaskModelPatcher,
106110
XverseModelPatcher,
107111
)
@@ -2611,3 +2615,69 @@ def patch_model_for_export(
26112615
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
26122616
) -> "ModelPatcher":
26132617
return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
2618+
2619+
2620+
@register_in_tasks_manager(
    "whisper",
    "feature-extraction",
    "feature-extraction-with-past",
    "audio-classification",
    "automatic-speech-recognition",
    "automatic-speech-recognition-with-past",
    library_name="transformers",
)
class WhisperOpenVINOConfig(WhisperOnnxConfig):
    """OpenVINO export config for Whisper with stateful decoder support."""

    def _is_stateful_decoder(self) -> bool:
        # Stateful export only concerns the decoder sub-config.
        return getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        """Use the stateful seq2seq decoder patcher where applicable, else default patching."""
        if self._is_stateful_decoder():
            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)

    @property
    def inputs(self):
        """Inherited inputs; in stateful decoder mode decoder_input_ids keeps a dynamic sequence length."""
        common_inputs = super().inputs
        if self._is_stateful_decoder():
            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
        return common_inputs
2645+
2646+
2647+
@register_in_tasks_manager(
    "t5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class T5OpenVINOConfig(T5OnnxConfig):
    """OpenVINO export config for T5 with stateful decoder support."""

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        """Return the stateful seq2seq patcher for a stateful decoder export, else defer to the base class."""
        stateful_decoder = getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER
        if not stateful_decoder:
            return super().patch_model_for_export(model, model_kwargs)
        return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)

    @property
    def inputs(self):
        """Inherited inputs; decoder_input_ids gets a dynamic sequence axis in stateful decoder mode."""
        inputs = super().inputs
        stateful_decoder = getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER
        if stateful_decoder:
            inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
        return inputs
2666+
2667+
2668+
@register_in_tasks_manager(
    "mt5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class MT5OpenVINOConfig(T5OpenVINOConfig):
    """mT5 reuses the T5 OpenVINO export configuration unchanged."""
2675+
2676+
2677+
@register_in_tasks_manager(
    "longt5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class LongT5OpenVINOConfig(T5OpenVINOConfig):
    """LongT5 reuses the T5 OpenVINO export configuration unchanged."""

optimum/exporters/openvino/model_patcher.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
2525
from transformers.utils import is_tf_available
2626

27-
from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
27+
from optimum.exporters.onnx.model_patcher import (
28+
DecoderModelPatcher,
29+
ModelPatcher,
30+
Seq2SeqModelPatcher,
31+
override_arguments,
32+
)
2833
from optimum.intel.utils.import_utils import (
2934
_openvino_version,
3035
_torch_version,
@@ -3740,3 +3745,49 @@ def __exit__(self, exc_type, exc_value, traceback):
37403745
if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
37413746
for layer in self._model.transformer.h:
37423747
layer.attn._attn = layer.attn._orig_attn
3748+
3749+
3750+
class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
    """Patch a seq2seq decoder for stateful OpenVINO export.

    Wraps the model's ``forward`` so that legacy tuple-format ``past_key_values``
    are converted to a ``transformers`` ``EncoderDecoderCache`` before the
    original forward runs, keeping only the decoder self-attention key/value
    pairs (the cross-attention cache is recomputed from encoder states), and
    converts the returned cache back to the legacy tuple format for tracing.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        model.__orig_forward = model.forward

        @functools.wraps(model.__orig_forward)
        def patched_forward(*args, **kwargs):
            from transformers.cache_utils import EncoderDecoderCache

            signature = inspect.signature(self.orig_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

            return_legacy_cache = False
            pkv_in_args = False
            legacy_pkv = None
            if "past_key_values" in kwargs:
                legacy_pkv = kwargs.pop("past_key_values", None)
            sign_names = list(signature.parameters.keys())
            pkv_argument_index = sign_names.index("past_key_values")
            if legacy_pkv is None and len(args) > pkv_argument_index:
                # past_key_values was passed positionally.
                legacy_pkv = args[pkv_argument_index]
                pkv_in_args = True
            if legacy_pkv is not None:
                # Keep only the self-attention key/value pair of each layer.
                only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
                pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
                return_legacy_cache = True
                if not pkv_in_args:
                    kwargs["past_key_values"] = pkv
                else:
                    # Fix: override_arguments may return `args` as a tuple, which
                    # does not support item assignment — convert first.
                    if isinstance(args, tuple):
                        args = list(args)
                    args[pkv_argument_index] = pkv

            outputs = model.__orig_forward(*args, **kwargs)
            if return_legacy_cache:
                # Restore the legacy tuple format expected by the export graph.
                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

            return outputs

        model.forward = patched_forward

        super().__init__(config, model, model_kwargs)

optimum/exporters/openvino/stateful.py

+93-2
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,95 @@ def ensure_stateful_is_available(warn=True):
190190
return True
191191

192192

193+
# Task families whose "-with-past" variants are eligible for stateful export.
_ENCODER_DECODER_TASKS_WITH_PAST = (
    "automatic-speech-recognition",
    "text2text-generation",
)

_DECODER_TASKS_WITH_PAST = ("text-generation",)


def ensure_export_task_support_stateful(task: str):
    """Return True when *task* (after synonym mapping) is a '-with-past' task that supports stateful export."""
    from optimum.exporters import TasksManager

    task = TasksManager.map_from_synonym(task)
    if not task.endswith("-with-past"):
        return False
    stateful_tasks = _ENCODER_DECODER_TASKS_WITH_PAST + _DECODER_TASKS_WITH_PAST
    return task.replace("-with-past", "") in stateful_tasks
198211

199212

200213
def ensure_model_type_support_stateful(model_type: str):
    """Whether this model type (underscores normalized to dashes) supports stateful export."""
    normalized_type = model_type.replace("_", "-")
    return normalized_type in MULTI_MODAL_TEXT_GENERATION_MODELS
202215

203216

204-
def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name: str = "input_ids"):
217+
def remove_parameters_by_names(model: ov.Model, names: list):
    """Remove the Parameter nodes behind the named model inputs."""
    # Resolve every node first so removals cannot disturb the name lookup.
    nodes_to_drop = [model.input(input_name).get_node() for input_name in names]
    for node in nodes_to_drop:
        model.remove_parameter(node)
221+
222+
223+
def get_input_nodes(node):
    """Return the producer node of every input port of *node*."""
    return [port.get_node() for port in node.input_values()]
225+
226+
227+
def find_dependent_nodes(model: ov.Model, sources: list):
    """Collect every node of *model* that directly or transitively depends on a node
    in *sources*, including the sources themselves.  A single pass suffices because
    get_ordered_ops() yields nodes in topological order, so a node's producers are
    classified before the node itself."""
    dependent = set(sources)
    for op in model.get_ordered_ops():
        if dependent.intersection(get_input_nodes(op)):
            dependent.add(op)
    return dependent
235+
236+
237+
def get_read_value_ops(model: ov.Model):
    """Return every ReadValue (state read) operation in *model*."""
    return [node for node in model.get_ops() if node.get_type_name() == "ReadValue"]
239+
240+
241+
def get_shape_of_ops(model: ov.Model):
    """Return every ShapeOf operation in *model*."""
    return list(filter(lambda node: node.get_type_name() == "ShapeOf", model.get_ops()))
243+
244+
245+
def get_consumer_nodes(node):
    """Return the set of nodes that consume any output of *node*."""
    consumers = set()
    for output in node.outputs():
        consumers.update(target.get_node() for target in output.get_target_inputs())
    return consumers
248+
249+
250+
def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):
    # Search for nodes in the model graph that depend on nodes in `sources` but are
    # independent of every other Parameter / ReadValue / ShapeOf of the model, and
    # return the boundary of that subgraph: its nodes that feed the rest of the graph.
    other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources)
    other_nodes = find_dependent_nodes(model, other_inputs)
    source_dependent_nodes = find_dependent_nodes(model, sources)
    # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph
    # Nodes reachable only from `sources` (no dependency on any other graph input).
    nodes = source_dependent_nodes - other_nodes
    # Edge nodes: inside the source-only subgraph but consumed by at least one node outside it.
    edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes]
    return edge_nodes
259+
260+
261+
def insert_state_for_nodes(model: ov.Model, nodes):
    # For each output of every node in `nodes`, splice in a ReadValue/Assign pair so
    # the value becomes persistent model state initialized from the original
    # sub-expression.
    all_outputs = [output for node in nodes for output in node.outputs()]
    for output in all_outputs:
        consumers = output.get_target_inputs()
        # FIXME: get_any_name is not reliable as tensor may not have any names
        variable_id = output.get_any_name()
        read_value = ov.runtime.opset13.read_value(output, variable_id)
        for consumer in consumers:
            consumer.replace_source_output(read_value.output(0))
        assign = ov.runtime.opset13.assign(read_value, variable_id)
        model.add_sinks([assign])
273+
274+
275+
def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
    """Dispatch the stateful transformation: encoder-decoder models (recognized by an
    `encoder_hidden_states` input/output) get the seq2seq variant, all others the
    plain decoder variant."""
    is_seq2seq = config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states")
    if is_seq2seq:
        return patch_stateful_encoder_decoder(config, ov_model)
    return patch_stateful_decoder(config, ov_model)
279+
280+
281+
def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
205282
"""
206283
Apply stateful transformation to model to hide key values inputs inside model.
207284
Select transformation parameters based on model architecture
@@ -236,3 +313,17 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name
236313
make_stateful(
237314
ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
238315
)
316+
317+
318+
def patch_stateful_encoder_decoder(config, ov_model):
    """Make an encoder-decoder model stateful: drop the encoder KV-cache parameters,
    apply the decoder stateful transform, then turn everything that depends only on
    `encoder_hidden_states` into initialized model state."""
    encoder_key_value_input_names = []
    for model_input in ov_model.inputs:
        if any("key_values" in name and "encoder" in name for name in model_input.get_names()):
            encoder_key_value_input_names.append(model_input.get_any_name())
    remove_parameters_by_names(ov_model, encoder_key_value_input_names)
    patch_stateful_decoder(config, ov_model)
    encoder_hidden_states_source = ov_model.input("encoder_hidden_states").get_node()
    insert_state_for_nodes(
        ov_model,
        find_output_nodes_of_dependent_subgraph(ov_model, [encoder_hidden_states_source]),
    )

0 commit comments

Comments
 (0)