Skip to content

Commit 9219632

Browse files
committed
respect from_onnx
1 parent 26e08d1 commit 9219632

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

optimum/intel/openvino/modeling_base_seq2seq.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,11 @@ def _from_pretrained(
206206
if not compile_only:
207207
encoder = cls.load_model(os.path.join(model_id, encoder_file_name), quantization_config)
208208
decoder = cls.load_model(os.path.join(model_id, decoder_file_name), quantization_config)
209-
if use_cache and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)):
209+
if (
210+
use_cache
211+
and not model_has_state(decoder)
212+
and os.path.exists(os.path.join(model_id, decoder_with_past_file_name))
213+
):
210214
decoder_with_past = cls.load_model(
211215
os.path.join(model_id, decoder_with_past_file_name), quantization_config
212216
)
@@ -223,7 +227,11 @@ def _from_pretrained(
223227
kwargs.get("ov_config"),
224228
model_save_dir,
225229
)
226-
if use_cache and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)):
230+
if (
231+
use_cache
232+
and not model_has_state(decoder)
233+
and os.path.exists(os.path.join(model_id, decoder_with_past_file_name))
234+
):
227235
decoder_with_past = cls._compile_model(
228236
os.path.join(model_id, decoder_with_past_file_name),
229237
kwargs.get("device", "CPU"),
@@ -259,8 +267,11 @@ def _from_pretrained(
259267
decoder = cls.load_model(file_names["decoder"], quantization_config)
260268
if use_cache and not model_has_state(decoder):
261269
model_file_names["decoder_with_past"] = decoder_with_past_file_name
262-
model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin")
263-
for name in ["decoder_with_past", "decoder_with_past_bin"]:
270+
with_past_files = ["decoder_with_past"]
271+
if not from_onnx:
272+
with_past_files.append("decoder_with_past_bin")
273+
model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin")
274+
for name in with_past_files:
264275
model_cache_path = hf_hub_download(
265276
repo_id=model_id,
266277
filename=model_file_names[name],
@@ -282,8 +293,11 @@ def _from_pretrained(
282293
)
283294
if use_cache and not model_has_state(decoder):
284295
model_file_names["decoder_with_past"] = decoder_with_past_file_name
285-
model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin")
286-
for name in ["decoder_with_past", "decoder_with_past_bin"]:
296+
with_past_files = ["decoder_with_past"]
297+
if not from_onnx:
298+
with_past_files.append("decoder_with_past_bin")
299+
model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin")
300+
for name in with_past_files:
287301
model_cache_path = hf_hub_download(
288302
repo_id=model_id,
289303
filename=model_file_names[name],

0 commit comments

Comments
 (0)