Skip to content

Commit bb18f27

Browse files
committed
Merge the decoder and decoder-with-past submodels into a single stateful model for seq2seq export
1 parent a76be08 commit bb18f27

File tree

6 files changed

+423
-55
lines changed

6 files changed

+423
-55
lines changed

optimum/exporters/openvino/convert.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
)
3333
from optimum.exporters.utils import (
3434
get_diffusion_models_for_export,
35+
DECODER_NAME,
36+
DECODER_WITH_PAST_NAME,
37+
ENCODER_NAME,
38+
_get_submodels_for_export_encoder_decoder,
3539
)
3640
from optimum.intel.utils.import_utils import (
3741
_diffusers_version,
@@ -43,6 +47,7 @@
4347
_torch_version,
4448
_transformers_version,
4549
compare_versions,
50+
is_openvino_version,
4651
is_openvino_tokenizers_version,
4752
is_tokenizers_version,
4853
is_transformers_version,
@@ -624,10 +629,14 @@ def export_from_model(
624629

625630
logger.info(f"Automatic task detection to: {task}.")
626631

632+
is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
633+
model_type = getattr(getattr(model, "config", {}), "model_type", "")
627634
stateful = stateful and (
628-
ensure_export_task_support_stateful(task)
629-
or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", ""))
635+
ensure_export_task_support_stateful(task, is_encoder_decoder) or ensure_model_type_support_stateful(model_type)
630636
)
637+
638+
if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
639+
stateful = False
631640
# TODO: support onnx_config.py in the model repo
632641
if custom_architecture and custom_export_configs is None:
633642
raise ValueError(
@@ -666,6 +675,11 @@ def export_from_model(
666675
if library_name == "diffusers":
667676
export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
668677
stateful_submodels = False
678+
elif stateful and is_encoder_decoder and not custom_architecture:
679+
export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
680+
model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
681+
)
682+
stateful_submodels = [False, True]
669683
else:
670684
logging.disable(logging.INFO)
671685
export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(
@@ -1188,3 +1202,43 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
11881202
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
11891203

11901204
return models_for_export
1205+
1206+
1207+
def _get_encoder_decoder_stateful_models_for_export(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    task: str,
    _variant: str,
    library_name: str,
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    preprocessors: Optional[List[Any]] = None,
):
    """
    Build the submodels and export configs for a stateful seq2seq export.

    Returns a 2-tuple ``(None, models_for_export)`` where ``models_for_export``
    maps submodel names to ``(submodel, export_config)`` pairs: the encoder,
    plus a single decoder entry that replaces the legacy decoder /
    decoder-with-past pair (the decoder-with-past submodel is re-registered
    under the plain decoder name with ``stateful = True`` on its config).
    """
    export_config_constructor = TasksManager.get_exporter_config_constructor(
        model=model, exporter="openvino", task=task, library_name=library_name
    )
    export_config = export_config_constructor(
        model.config,
        int_dtype=int_dtype,
        float_dtype=float_dtype,
        preprocessors=preprocessors,
        legacy=False,
    )

    export_config.variant = _variant
    all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
    logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")

    # use_past=True so the helper also produces the decoder-with-past submodel.
    models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True)

    encoder_export_config = export_config.with_behavior("encoder")
    models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)

    decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)

    # Marking the config stateful makes patch_model_for_export select the
    # StatefulSeq2SeqDecoderPatcher for this decoder.
    decoder_export_config_with_past.stateful = True
    decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME)
    # The with-past decoder is exported under the plain decoder name: one
    # merged, stateful decoder instead of two separate submodels.
    models_for_export[DECODER_NAME] = (
        decoder_with_past_model,
        decoder_export_config_with_past,
    )
    # First tuple element (the top-level export config) is intentionally None.
    return None, models_for_export

optimum/exporters/openvino/model_configs.py

+56
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
2121
from transformers.utils import is_tf_available
2222

23+
from optimum.exporters.onnx.base import ConfigBehavior
2324
from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
2425
from optimum.exporters.onnx.model_configs import (
2526
CLIPOnnxConfig,
@@ -36,8 +37,10 @@
3637
MistralOnnxConfig,
3738
MPTOnnxConfig,
3839
PhiOnnxConfig,
40+
T5OnnxConfig,
3941
UNetOnnxConfig,
4042
VisionOnnxConfig,
43+
WhisperOnnxConfig,
4144
)
4245
from optimum.exporters.onnx.model_patcher import ModelPatcher
4346
from optimum.exporters.tasks import TasksManager
@@ -90,6 +93,7 @@
9093
Phi3VisionImageEmbeddingsPatcher,
9194
QwenModelPatcher,
9295
RotaryEmbPatcher,
96+
StatefulSeq2SeqDecoderPatcher,
9397
UpdateCausalMaskModelPatcher,
9498
XverseModelPatcher,
9599
)
@@ -2260,3 +2264,55 @@ def patch_model_for_export(
22602264
if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
22612265
return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
22622266
return super().patch_model_for_export(model, model_kwargs)
2267+
2268+
2269+
@register_in_tasks_manager(
    "whisper",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "audio-classification",
        "automatic-speech-recognition",
        "automatic-speech-recognition-with-past",
    ],
    library_name="transformers",
)
class WhisperOpenVINOConfig(WhisperOnnxConfig):
    """OpenVINO export config for Whisper; selects the stateful decoder patcher."""

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        # When exporting the decoder of a stateful seq2seq model, the forward
        # must be patched to translate the legacy tuple KV-cache into an
        # EncoderDecoderCache (and back) — see StatefulSeq2SeqDecoderPatcher.
        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)
2287+
2288+
2289+
@register_in_tasks_manager(
    "t5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class T5OpenVINOConfig(T5OnnxConfig):
    """OpenVINO export config for T5; selects the stateful decoder patcher."""

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        """Return the stateful decoder patcher for a stateful decoder export, else defer to the base class."""
        exporting_stateful_decoder = getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER
        if exporting_stateful_decoder:
            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)
2301+
2302+
2303+
@register_in_tasks_manager(
    "mt5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class MT5OpenVINOConfig(T5OpenVINOConfig):
    """OpenVINO export config for mT5: reuses the T5 config unchanged."""

    pass
2310+
2311+
2312+
@register_in_tasks_manager(
    "longt5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class LongT5OpenVINOConfig(T5OpenVINOConfig):
    """OpenVINO export config for LongT5: reuses the T5 config unchanged."""

    pass

optimum/exporters/openvino/model_patcher.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
2525
from transformers.utils import is_tf_available
2626

27-
from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
27+
from optimum.exporters.onnx.model_patcher import (
28+
DecoderModelPatcher,
29+
ModelPatcher,
30+
override_arguments,
31+
Seq2SeqModelPatcher,
32+
)
2833
from optimum.intel.utils.import_utils import (
2934
_openvino_version,
3035
_torch_version,
@@ -3378,3 +3383,49 @@ def __exit__(self, exc_type, exc_value, traceback):
33783383
super().__exit__(exc_type, exc_value, traceback)
33793384
for block in self._model.model.layers:
33803385
block.self_attn.forward = block.self_attn._orig_forward
3386+
3387+
3388+
class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
    """
    Patch a seq2seq decoder's forward for stateful export: the legacy tuple
    KV-cache fed by the exporter is converted to a ``transformers``
    ``EncoderDecoderCache`` before the call and converted back afterwards.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # NOTE(review): Python name mangling stores this attribute as
        # `_StatefulSeq2SeqDecoderPatcher__orig_forward`; it is only read back
        # inside this class body, so the mangled name is used consistently.
        model.__orig_forward = model.forward

        @functools.wraps(model.__orig_forward)
        def patched_forward(*args, **kwargs):
            from transformers.cache_utils import EncoderDecoderCache

            signature = inspect.signature(self.orig_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

            return_legacy_cache = False
            pkv_in_args = False
            legacy_pkv = None
            if "past_key_values" in kwargs:
                legacy_pkv = kwargs.pop("past_key_values", None)
            sign_names = list(signature.parameters.keys())
            pkv_argument_index = sign_names.index("past_key_values")
            # past_key_values may also arrive positionally rather than as a kwarg.
            if legacy_pkv is None and len(args) > pkv_argument_index:
                legacy_pkv = args[pkv_argument_index]
                pkv_in_args = True
            if legacy_pkv is not None:
                # Keep only the first two entries per layer (self-attention
                # key/value); cross-attention entries are dropped here —
                # presumably served from model state in the stateful export,
                # see patch_stateful_encoder_decoder. TODO confirm.
                only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
                pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
                return_legacy_cache = True
                if not pkv_in_args:
                    kwargs["past_key_values"] = pkv
                else:
                    # assumes override_arguments returned a mutable args sequence — TODO confirm
                    args[pkv_argument_index] = pkv

            outputs = model.__orig_forward(*args, **kwargs)
            # Restore the legacy tuple layout expected by the exporter.
            if return_legacy_cache:
                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

            return outputs

        model.forward = patched_forward

        super().__init__(config, model, model_kwargs)

optimum/exporters/openvino/stateful.py

+92-3
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,93 @@ def ensure_stateful_is_available(warn=True):
190190
return True
191191

192192

193-
def ensure_export_task_support_stateful(task: str):
193+
# Encoder-decoder task families whose "-with-past" variants support the
# stateful (merged decoder) export.
_ENCODER_DECODER_TASKS_WITH_PAST = (
    "automatic-speech-recognition",
    "text2text-generation",
)


def ensure_export_task_support_stateful(task: str, is_encoder_decoder: bool = False):
    """Return whether the given export task can be exported as a stateful model."""
    from optimum.exporters import TasksManager

    normalized_task = TasksManager.map_from_synonym(task)

    # Decoder-only models: only KV-cache text generation is stateful.
    if not is_encoder_decoder:
        return normalized_task in ["text-generation-with-past"]

    # Encoder-decoder models: the task must be a "-with-past" variant of a
    # supported seq2seq task family.
    base_task = normalized_task.replace("-with-past", "")
    return normalized_task.endswith("-with-past") and base_task in _ENCODER_DECODER_TASKS_WITH_PAST
198209

199210

200211
def ensure_model_type_support_stateful(model_type: str):
    """Return True if the model type (normalized to dashes) is one of the multi-modal text-generation models."""
    return model_type.replace("_", "-") in MULTI_MODAL_TEXT_GENERATION_MODELS
202213

203214

204-
def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name: str = "input_ids"):
215+
def remove_parameters_by_names(model: "ov.Model", names: list):
    """Detach from *model* the Parameter nodes backing the named inputs."""
    # Resolve every parameter node first, then detach them one by one.
    parameter_nodes = [model.input(input_name).get_node() for input_name in names]
    for parameter_node in parameter_nodes:
        model.remove_parameter(parameter_node)
219+
220+
221+
def get_input_nodes(node):
    """Return the producer node of each input value of *node*, in input order."""
    producers = []
    for input_value in node.input_values():
        producers.append(input_value.get_node())
    return producers
223+
224+
225+
def find_dependent_nodes(model: "ov.Model", sources: list):
    """
    Return `sources` plus every node of *model* directly or indirectly
    dependent on at least one node from `sources`.

    A single forward sweep suffices because `get_ordered_ops` yields the
    graph in topological order.
    """
    dependent = set(sources)
    for op in model.get_ordered_ops():
        if any(value.get_node() in dependent for value in op.input_values()):
            dependent.add(op)
    return dependent
233+
234+
235+
def get_read_value_ops(model: "ov.Model"):
    """Return all ReadValue (state read) operations of *model*."""
    read_values = []
    for op in model.get_ops():
        if op.get_type_name() == "ReadValue":
            read_values.append(op)
    return read_values
237+
238+
239+
def get_shape_of_ops(model: "ov.Model"):
    """Return all ShapeOf operations of *model*."""
    shape_ops = []
    for op in model.get_ops():
        if op.get_type_name() == "ShapeOf":
            shape_ops.append(op)
    return shape_ops
241+
242+
243+
def get_consumer_nodes(node):
    """Return the set of nodes consuming at least one output of *node*."""
    consumers = set()
    for output in node.outputs():
        for target_input in output.get_target_inputs():
            consumers.add(target_input.get_node())
    return consumers
246+
247+
248+
def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):
    # Search for nodes in the model graph that depend on nodes in the `sources` list but are independent of other model Parameter's/ReadValue's.
    # Returns the "edge" nodes of that source-only subgraph: the ones whose
    # value is consumed by at least one node outside of the subgraph.
    other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources)
    other_nodes = find_dependent_nodes(model, other_inputs)
    source_dependent_nodes = find_dependent_nodes(model, sources)
    # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph
    # Nodes reachable from `sources` only (not also from other inputs/states).
    nodes = source_dependent_nodes - other_nodes
    # Keep only nodes feeding a consumer outside of the subgraph.
    edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes]
    return edge_nodes
257+
258+
259+
def insert_state_for_nodes(model: ov.Model, nodes):
    """
    For each output in the given list `nodes` of ov.Node's, insert a
    ReadValue-Assign pair and use the node output as the state
    initialization sub-expression, rerouting all former consumers to read
    from the state instead.
    """
    outputs = sum((node.outputs() for node in nodes), [])
    for index, output in enumerate(outputs):
        consumers = output.get_target_inputs()
        # A tensor is not guaranteed to carry any names (the previous
        # unconditional get_any_name() call could fail on such outputs), so
        # fall back to an id generated from the producer's friendly name.
        if output.get_names():
            variable_id = output.get_any_name()
        else:
            variable_id = f"{output.get_node().get_friendly_name()}_state_{index}"
        read_value = ov.runtime.opset13.read_value(output, variable_id)
        for consumer in consumers:
            consumer.replace_source_output(read_value.output(0))
        # Pair the ReadValue with an Assign sink so the variable is kept in the model.
        assign = ov.runtime.opset13.assign(read_value, variable_id)
        model.add_sinks([assign])
271+
272+
273+
def patch_stateful(config: "PretrainedConfig", ov_model: "ov.Model"):
    """Dispatch the stateful transformation to the encoder-decoder or the
    decoder-only variant, based on the config and the model's graph inputs."""
    is_seq2seq_decoder = config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states")
    handler = patch_stateful_encoder_decoder if is_seq2seq_decoder else patch_stateful_decoder
    return handler(config, ov_model)
277+
278+
279+
def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
205280
"""
206281
Apply stateful transformation to model to hide key values inputs inside model.
207282
Select transformation parameters based on model architecture
@@ -236,3 +311,17 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name
236311
make_stateful(
237312
ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
238313
)
314+
315+
316+
def patch_stateful_encoder_decoder(config, ov_model):
    """
    Apply the stateful transformation to a merged seq2seq decoder model.

    The cross-attention (encoder) key/value inputs are removed as Parameters,
    the decoder self-attention cache is hidden into model state via
    patch_stateful_decoder, and the subgraph depending only on
    `encoder_hidden_states` gets ReadValue/Assign state inserted at its edge
    outputs (see insert_state_for_nodes), so those values live in model state
    instead of being recomputed from explicit inputs.
    """
    # Inputs whose names mention both "key_values" and "encoder" are the
    # cross-attention cache Parameters to be replaced by state.
    encoder_key_value_input_names = [
        key.get_any_name()
        for key in ov_model.inputs
        if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
    ]
    remove_parameters_by_names(ov_model, encoder_key_value_input_names)
    patch_stateful_decoder(config, ov_model)
    insert_state_for_nodes(
        ov_model,
        find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]),
    )

0 commit comments

Comments
 (0)