Commit 74ee7eb

Merge decoder and decoder with past to stateful for seq2seq (#1078)
* merge decoder and decoder with past to stateful for seq2seq
* fix quantization
* fix loading decoder_with_past
* fix quant tests
* fix tests
* fix more tests
* make input dynamic and enable sdpa
* review comments and kv cache compression disable in fp
* fix task recognition
* fix quantization tests
* respect from_onnx
* update test to check that stateful expected
* Apply suggestions from code review

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
1 parent fe55db5 commit 74ee7eb

11 files changed (+552 −91 lines)

optimum/exporters/openvino/__main__.py (+5 −1)

@@ -274,7 +274,11 @@ def main_export(
                 f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
             )
 
-        if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
+        if (
+            is_transformers_version(">=", "4.36")
+            and is_transformers_version("<=", "4.45.0")
+            and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
+        ):
             loading_kwargs["attn_implementation"] = "eager"
 
         # some models force flash_attn attention by default that does not support load model on cpu
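
Note: the tightened version gate means the eager-attention fallback is injected only for 4.36 <= transformers <= 4.45.0; newer releases load these architectures with their default attention. A minimal sketch of the loading-side effect (the checkpoint id is illustrative):

# Sketch: what loading_kwargs["attn_implementation"] = "eager" amounts to
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-tiny",        # illustrative checkpoint
    attn_implementation="eager",  # forced by the exporter inside the gated range
)
print(model.config._attn_implementation)  # -> "eager"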

optimum/exporters/openvino/convert.py (+59 −5)

@@ -28,10 +28,13 @@
 from openvino.tools.ovc import convert_model
 from optimum.exporters import TasksManager
 from optimum.exporters.utils import (
-    _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
+    DECODER_NAME,
+    ENCODER_NAME,
+    _get_submodels_for_export_encoder_decoder,
+    get_diffusion_models_for_export,
 )
 from optimum.exporters.utils import (
-    get_diffusion_models_for_export,
+    _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
 )
 from optimum.intel.utils.import_utils import (
     _diffusers_version,

@@ -106,10 +109,13 @@ def _set_runtime_options(
             "diffusers" in library_name
             or "text-generation" in task
             or ("image-text-to-text" in task and model_name == "language_model")
+            or getattr(sub_export_config, "stateful", False)
         ):
             sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
         if not quantized_model and (
-            "text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
+            "text-generation" in task
+            or ("image-text-to-text" in task and model_name == "language_model")
+            or getattr(sub_export_config, "stateful", False)
         ):
             sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16"

@@ -642,10 +648,14 @@ def export_from_model(
 
     logger.info(f"Automatic task detection to: {task}.")
 
+    is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
+    model_type = getattr(getattr(model, "config", {}), "model_type", "")
     stateful = stateful and (
-        ensure_export_task_support_stateful(task)
-        or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", ""))
+        ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type)
     )
+
+    if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
+        stateful = False
     # TODO: support onnx_config.py in the model repo
     if custom_architecture and custom_export_configs is None:
         raise ValueError(

@@ -687,6 +697,11 @@ def export_from_model(
     if library_name == "diffusers":
         export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
         stateful_submodels = False
+    elif stateful and is_encoder_decoder and not custom_architecture:
+        export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
+            model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
+        )
+        stateful_submodels = [False, True]
     else:
        logging.disable(logging.INFO)
        export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(

@@ -1221,3 +1236,42 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
     models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
 
     return models_for_export
+
+
+def _get_encoder_decoder_stateful_models_for_export(
+    model: Union["PreTrainedModel", "TFPreTrainedModel"],
+    task: str,
+    _variant: str,
+    library_name: str,
+    int_dtype: str = "int64",
+    float_dtype: str = "fp32",
+    preprocessors: Optional[List[Any]] = None,
+):
+    export_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=model, exporter="openvino", task=task, library_name=library_name
+    )
+    export_config = export_config_constructor(
+        model.config,
+        int_dtype=int_dtype,
+        float_dtype=float_dtype,
+        preprocessors=preprocessors,
+        legacy=False,
+    )
+
+    export_config.variant = _variant
+    all_variants = "\n".join([f"    - {name}: {description}" for name, description in export_config.VARIANTS.items()])
+    logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")
+
+    models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False)
+
+    encoder_export_config = export_config.with_behavior("encoder")
+    models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)
+
+    decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
+
+    decoder_export_config_with_past.stateful = True
+    models_for_export[DECODER_NAME] = (
+        models_for_export[DECODER_NAME],
+        decoder_export_config_with_past,
+    )
+    return None, models_for_export
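
Note: the new helper returns a single stateful decoder alongside the encoder, replacing the former decoder/decoder_with_past pair (hence stateful_submodels = [False, True] above). A rough sketch of the resulting mapping, assuming optimum's encoder_model/decoder_model submodel names (checkpoint id and task are illustrative):

# Sketch: inspecting the submodels produced for a seq2seq export
from transformers import AutoModelForSeq2SeqLM
from optimum.exporters.openvino.convert import _get_encoder_decoder_stateful_models_for_export

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # illustrative
_, models_for_export = _get_encoder_decoder_stateful_models_for_export(
    model=model,
    task="text2text-generation-with-past",
    _variant="default",
    library_name="transformers",
)
for name, (_submodel, export_config) in models_for_export.items():
    # only the decoder entry is expected to carry stateful=True
    print(name, getattr(export_config, "stateful", False))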

optimum/exporters/openvino/model_configs.py (+70)

@@ -20,6 +20,7 @@
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
 from transformers.utils import is_tf_available
 
+from optimum.exporters.onnx.base import ConfigBehavior
 from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.onnx.model_configs import (
     CLIPOnnxConfig,

@@ -38,8 +39,10 @@
     MistralOnnxConfig,
     MPTOnnxConfig,
     PhiOnnxConfig,
+    T5OnnxConfig,
     UNetOnnxConfig,
     VisionOnnxConfig,
+    WhisperOnnxConfig,
 )
 from optimum.exporters.onnx.model_patcher import ModelPatcher
 from optimum.exporters.tasks import TasksManager

@@ -102,6 +105,7 @@
     Qwen2VLVisionEmbMergerPatcher,
     QwenModelPatcher,
     RotaryEmbPatcher,
+    StatefulSeq2SeqDecoderPatcher,
     UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
 )

@@ -2611,3 +2615,69 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "whisper",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "audio-classification",
+        "automatic-speech-recognition",
+        "automatic-speech-recognition-with-past",
+    ],
+    library_name="transformers",
+)
+class WhisperOpenVINOConfig(WhisperOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+        return common_inputs
+
+
+@register_in_tasks_manager(
+    "t5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class T5OpenVINOConfig(T5OnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+        return common_inputs
+
+
+@register_in_tasks_manager(
+    "mt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class MT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "longt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class LongT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
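
Note: with these registrations, whisper, t5, mt5, and longt5 route their decoder through StatefulSeq2SeqDecoderPatcher whenever a "-with-past" task is exported with stateful=True. A minimal export sketch using optimum-intel's main_export (checkpoint id and output directory are illustrative):

# Sketch: exporting one of the newly registered seq2seq architectures
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="openai/whisper-tiny",       # illustrative checkpoint
    output="whisper_ov_stateful",                   # illustrative output dir
    task="automatic-speech-recognition-with-past",  # "-with-past" enables the stateful path
    stateful=True,
)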

optimum/exporters/openvino/model_patcher.py (+52 −1)

@@ -24,7 +24,12 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available
 
-from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
+from optimum.exporters.onnx.model_patcher import (
+    DecoderModelPatcher,
+    ModelPatcher,
+    Seq2SeqModelPatcher,
+    override_arguments,
+)
 from optimum.intel.utils.import_utils import (
     _openvino_version,
     _torch_version,

@@ -3740,3 +3745,49 @@ def __exit__(self, exc_type, exc_value, traceback):
         if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
             for layer in self._model.transformer.h:
                 layer.attn._attn = layer.attn._orig_attn
+
+
+class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        @functools.wraps(model.__orig_forward)
+        def patched_forward(*args, **kwargs):
+            from transformers.cache_utils import EncoderDecoderCache
+
+            signature = inspect.signature(self.orig_forward)
+            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
+
+            return_legacy_cache = False
+            pkv_in_args = False
+            legacy_pkv = None
+            if "past_key_values" in kwargs:
+                legacy_pkv = kwargs.pop("past_key_values", None)
+            sign_names = list(signature.parameters.keys())
+            pkv_argument_index = sign_names.index("past_key_values")
+            if legacy_pkv is None and len(args) > pkv_argument_index:
+                legacy_pkv = args[pkv_argument_index]
+                pkv_in_args = True
+            if legacy_pkv is not None:
+                only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
+                pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
+                return_legacy_cache = True
+                if not pkv_in_args:
+                    kwargs["past_key_values"] = pkv
+                else:
+                    args[pkv_argument_index] = pkv
+
+            outputs = model.__orig_forward(*args, **kwargs)
+            if return_legacy_cache:
+                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+
+            return outputs
+
+        model.forward = patched_forward
+
+        super().__init__(config, model, model_kwargs)
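
Note: the patcher translates between the legacy per-layer tuple cache the exporter traces with and the EncoderDecoderCache that recent transformers releases expect, keeping only the self-attention tensors on the way in. A small standalone sketch of that round trip (shapes are illustrative):

# Sketch: the cache conversion performed inside patched_forward
import torch
from transformers.cache_utils import EncoderDecoderCache

# one layer of legacy cache: (self_key, self_value, cross_key, cross_value)
legacy_pkv = [(
    torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64),
    torch.zeros(1, 8, 10, 64), torch.zeros(1, 8, 10, 64),
)]

only_self_cache = [layer[:2] for layer in legacy_pkv]          # drop cross-attention entries
pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)   # what the model's forward receives
# ... decoder forward runs with pkv ...
restored = pkv.to_legacy_cache()                               # back to per-layer tuples for the traced outputs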

optimum/exporters/openvino/stateful.py (+93 −2)

@@ -190,18 +190,95 @@ def ensure_stateful_is_available(warn=True):
     return True
 
 
+_ENCODER_DECODER_TASKS_WITH_PAST = (
+    "automatic-speech-recognition",
+    "text2text-generation",
+)
+
+_DECODER_TASKS_WITH_PAST = ("text-generation",)
+
+
 def ensure_export_task_support_stateful(task: str):
     from optimum.exporters import TasksManager
 
     task = TasksManager.map_from_synonym(task)
-    return task in ["text-generation-with-past"]
+
+    is_stateful = (
+        task.endswith("-with-past")
+        and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST + _DECODER_TASKS_WITH_PAST
+    )
+    return is_stateful
 
 
 def ensure_model_type_support_stateful(model_type: str):
     return model_type.replace("_", "-") in MULTI_MODAL_TEXT_GENERATION_MODELS
 
 
-def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name: str = "input_ids"):
+def remove_parameters_by_names(model: ov.Model, names: list):
+    parameters = [model.input(name).get_node() for name in names]
+    for p in parameters:
+        model.remove_parameter(p)
+
+
+def get_input_nodes(node):
+    return [input.get_node() for input in node.input_values()]
+
+
+def find_dependent_nodes(model: ov.Model, sources: list):
+    # Finds all nodes in `model` that directly or indirectly depend on at least one node from `sources`, including `sources` themselves
+    result = set(sources)
+    for node in model.get_ordered_ops():
+        input_nodes = set(get_input_nodes(node))
+        if input_nodes & result:
+            result.add(node)
+    return result
+
+
+def get_read_value_ops(model: ov.Model):
+    return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"]
+
+
+def get_shape_of_ops(model: ov.Model):
+    return [op for op in model.get_ops() if op.get_type_name() == "ShapeOf"]
+
+
+def get_consumer_nodes(node):
+    consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()])
+    return {input.get_node() for input in consumer_inputs}
+
+
+def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):
+    # Search for nodes in the model graph that depend on nodes in `sources` but are independent of the other model Parameters/ReadValues
+    other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources)
+    other_nodes = find_dependent_nodes(model, other_inputs)
+    source_dependent_nodes = find_dependent_nodes(model, sources)
+    # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph
+    nodes = source_dependent_nodes - other_nodes
+    edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes]
+    return edge_nodes
+
+
+def insert_state_for_nodes(model: ov.Model, nodes):
+    # For each output of the given `nodes`, insert a ReadValue-Assign pair and use the node output as the initialization sub-expression
+    outputs = sum((node.outputs() for node in nodes), [])
+    for output in outputs:
+        consumers = output.get_target_inputs()
+        # FIXME: get_any_name is not reliable as a tensor may not have any names
+        variable_id = output.get_any_name()
+        read_value = ov.runtime.opset13.read_value(output, variable_id)
+        for consumer in consumers:
+            consumer.replace_source_output(read_value.output(0))
+        assign = ov.runtime.opset13.assign(read_value, variable_id)
+        model.add_sinks([assign])
+
+
+def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
+    if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"):
+        return patch_stateful_encoder_decoder(config, ov_model)
+    return patch_stateful_decoder(config, ov_model)
+
+
+def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
     """
     Apply stateful transformation to model to hide key values inputs inside model.
     Select transformation parameters based on model architecture

@@ -236,3 +313,17 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name
     make_stateful(
         ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
     )
+
+
+def patch_stateful_encoder_decoder(config, ov_model):
+    encoder_key_value_input_names = [
+        key.get_any_name()
+        for key in ov_model.inputs
+        if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
+    ]
+    remove_parameters_by_names(ov_model, encoder_key_value_input_names)
+    patch_stateful_decoder(config, ov_model)
+    insert_state_for_nodes(
+        ov_model,
+        find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]),
+    )
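
Note: end to end, the export now yields an encoder plus one stateful decoder whose self-attention cache and encoder-dependent cross-attention subgraph live in ReadValue/Assign state. Assuming the modeling-side changes elsewhere in this commit, the usual optimum-intel classes consume it directly; a hedged usage sketch (checkpoint id is illustrative):

# Sketch: running a stateful seq2seq export through optimum-intel
from optimum.intel import OVModelForSeq2SeqLM
from transformers import AutoTokenizer

model = OVModelForSeq2SeqLM.from_pretrained("t5-small", export=True)  # converts on the fly
tokenizer = AutoTokenizer.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))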
