@@ -206,7 +206,11 @@ def _from_pretrained(
206
206
if not compile_only :
207
207
encoder = cls .load_model (os .path .join (model_id , encoder_file_name ), quantization_config )
208
208
decoder = cls .load_model (os .path .join (model_id , decoder_file_name ), quantization_config )
209
- if use_cache and os .path .exists (os .path .join (model_id , decoder_with_past_file_name )):
209
+ if (
210
+ use_cache
211
+ and not model_has_state (decoder )
212
+ and os .path .exists (os .path .join (model_id , decoder_with_past_file_name ))
213
+ ):
210
214
decoder_with_past = cls .load_model (
211
215
os .path .join (model_id , decoder_with_past_file_name ), quantization_config
212
216
)
@@ -223,7 +227,11 @@ def _from_pretrained(
223
227
kwargs .get ("ov_config" ),
224
228
model_save_dir ,
225
229
)
226
- if use_cache and os .path .exists (os .path .join (model_id , decoder_with_past_file_name )):
230
+ if (
231
+ use_cache
232
+ and not model_has_state (decoder )
233
+ and os .path .exists (os .path .join (model_id , decoder_with_past_file_name ))
234
+ ):
227
235
decoder_with_past = cls ._compile_model (
228
236
os .path .join (model_id , decoder_with_past_file_name ),
229
237
kwargs .get ("device" , "CPU" ),
@@ -259,8 +267,11 @@ def _from_pretrained(
259
267
decoder = cls .load_model (file_names ["decoder" ], quantization_config )
260
268
if use_cache and not model_has_state (decoder ):
261
269
model_file_names ["decoder_with_past" ] = decoder_with_past_file_name
262
- model_file_names ["decoder_with_past_bin" ] = decoder_with_past_file_name .replace (".xml" , ".bin" )
263
- for name in ["decoder_with_past" , "decoder_with_past_bin" ]:
270
+ with_past_files = ["decoder_with_past" ]
271
+ if not from_onnx :
272
+ with_past_files .append ("decoder_with_past_bin" )
273
+ model_file_names ["decoder_with_past_bin" ] = decoder_with_past_file_name .replace (".xml" , ".bin" )
274
+ for name in with_past_files :
264
275
model_cache_path = hf_hub_download (
265
276
repo_id = model_id ,
266
277
filename = model_file_names [name ],
@@ -282,8 +293,11 @@ def _from_pretrained(
282
293
)
283
294
if use_cache and not model_has_state (decoder ):
284
295
model_file_names ["decoder_with_past" ] = decoder_with_past_file_name
285
- model_file_names ["decoder_with_past_bin" ] = decoder_with_past_file_name .replace (".xml" , ".bin" )
286
- for name in ["decoder_with_past" , "decoder_with_past_bin" ]:
296
+ with_past_files = ["decoder_with_past" ]
297
+ if not from_onnx :
298
+ with_past_files .append ("decoder_with_past_bin" )
299
+ model_file_names ["decoder_with_past_bin" ] = decoder_with_past_file_name .replace (".xml" , ".bin" )
300
+ for name in with_past_files :
287
301
model_cache_path = hf_hub_download (
288
302
repo_id = model_id ,
289
303
filename = model_file_names [name ],
0 commit comments