Skip to content

Commit 75c653d

Browse files
committed
review comments and kv cache compression disable in fp
1 parent 1066e01 commit 75c653d

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

optimum/exporters/openvino/convert.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from optimum.exporters import TasksManager
3030
from optimum.exporters.utils import (
3131
DECODER_NAME,
32-
DECODER_WITH_PAST_NAME,
3332
ENCODER_NAME,
3433
_get_submodels_for_export_encoder_decoder,
3534
get_diffusion_models_for_export,
@@ -48,7 +47,6 @@
4847
_transformers_version,
4948
compare_versions,
5049
is_diffusers_version,
51-
is_openvino_version,
5250
is_openvino_tokenizers_version,
5351
is_tokenizers_version,
5452
is_transformers_version,
@@ -110,10 +108,13 @@ def _set_runtime_options(
110108
"diffusers" in library_name
111109
or "text-generation" in task
112110
or ("image-text-to-text" in task and model_name == "language_model")
111+
or getattr(sub_export_config, "stateful", False)
113112
):
114113
sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
115114
if not quantized_model and (
116-
"text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
115+
"text-generation" in task
116+
or ("image-text-to-text" in task and model_name == "language_model")
117+
or getattr(sub_export_config, "stateful", False)
117118
):
118119
sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16"
119120

@@ -643,7 +644,7 @@ def export_from_model(
643644
is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
644645
model_type = getattr(getattr(model, "config", {}), "model_type", "")
645646
stateful = stateful and (
646-
ensure_export_task_support_stateful(task, is_encoder_decoder) or ensure_model_type_support_stateful(model_type)
647+
ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type)
647648
)
648649

649650
if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
@@ -1251,17 +1252,16 @@ def _get_encoder_decoder_stateful_models_for_export(
12511252
all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
12521253
logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")
12531254

1254-
models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True)
1255+
models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False)
12551256

12561257
encoder_export_config = export_config.with_behavior("encoder")
12571258
models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)
12581259

12591260
decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
12601261

12611262
decoder_export_config_with_past.stateful = True
1262-
decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME)
12631263
models_for_export[DECODER_NAME] = (
1264-
decoder_with_past_model,
1264+
models_for_export[DECODER_NAME],
12651265
decoder_export_config_with_past,
12661266
)
12671267
return None, models_for_export

optimum/exporters/openvino/model_patcher.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3700,4 +3700,4 @@ def patched_forward(*args, **kwargs):
37003700

37013701
model.forward = patched_forward
37023702

3703-
super().__init__(config, model, model_kwargs)
3703+
super().__init__(config, model, model_kwargs)

optimum/intel/openvino/modeling_seq2seq.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,9 @@ def forward(
374374
# Decode
375375
if past_key_values is None or self.decoder_with_past is None:
376376
decoder_outputs = self.decoder(
377-
input_ids=decoder_input_ids[:, -1:]
378-
if past_key_values is not None and self.use_cache
379-
else decoder_input_ids,
377+
input_ids=(
378+
decoder_input_ids[:, -1:] if past_key_values is not None and self.use_cache else decoder_input_ids
379+
),
380380
past_key_values=past_key_values,
381381
encoder_hidden_states=encoder_outputs.last_hidden_state,
382382
encoder_attention_mask=attention_mask,

0 commit comments

Comments
 (0)