@@ -197,7 +197,6 @@ def _from_pretrained(
197
197
model_id = str (model_id )
198
198
sub_models_to_load , _ , _ = cls .extract_init_dict (config )
199
199
sub_models_names = set (sub_models_to_load .keys ()).intersection ({"feature_extractor" , "tokenizer" , "scheduler" })
200
- sub_models = {}
201
200
202
201
if not os .path .isdir (model_id ):
203
202
patterns = set (config .keys ())
@@ -231,16 +230,19 @@ def _from_pretrained(
231
230
new_model_save_dir = Path (model_id )
232
231
233
232
for name in sub_models_names :
233
+ # Check if the subcomponent needs to be loaded
234
+ if kwargs .get (name , None ) is not None :
235
+ continue
234
236
library_name , library_classes = sub_models_to_load [name ]
235
237
if library_classes is not None :
236
238
library = importlib .import_module (library_name )
237
239
class_obj = getattr (library , library_classes )
238
240
load_method = getattr (class_obj , "from_pretrained" )
239
241
# Check if the module is in a subdirectory
240
242
if (new_model_save_dir / name ).is_dir ():
241
- sub_models [name ] = load_method (new_model_save_dir / name )
243
+ kwargs [name ] = load_method (new_model_save_dir / name )
242
244
else :
243
- sub_models [name ] = load_method (new_model_save_dir )
245
+ kwargs [name ] = load_method (new_model_save_dir )
244
246
245
247
vae_decoder = cls .load_model (
246
248
new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name
@@ -260,9 +262,9 @@ def _from_pretrained(
260
262
text_encoder = text_encoder ,
261
263
unet = unet ,
262
264
config = config ,
263
- tokenizer = sub_models [ "tokenizer" ] ,
264
- scheduler = sub_models [ "scheduler" ] ,
265
- feature_extractor = sub_models .pop ("feature_extractor" , None ),
265
+ tokenizer = kwargs . pop ( "tokenizer" ) ,
266
+ scheduler = kwargs . pop ( "scheduler" ) ,
267
+ feature_extractor = kwargs .pop ("feature_extractor" , None ),
266
268
vae_encoder = vae_encoder ,
267
269
model_save_dir = model_save_dir ,
268
270
** kwargs ,
@@ -279,6 +281,9 @@ def _from_transformers(
279
281
cache_dir : Optional [str ] = None ,
280
282
local_files_only : bool = False ,
281
283
task : Optional [str ] = None ,
284
+ tokenizer : "CLIPTokenizer" = None ,
285
+ scheduler : Union ["DDIMScheduler" , "PNDMScheduler" , "LMSDiscreteScheduler" ] = None ,
286
+ feature_extractor : Optional ["CLIPFeatureExtractor" ] = None ,
282
287
** kwargs ,
283
288
):
284
289
if task is None :
@@ -303,13 +308,7 @@ def _from_transformers(
303
308
os .path .join (DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER , ONNX_WEIGHTS_NAME ),
304
309
]
305
310
models_and_onnx_configs = get_stable_diffusion_models_for_export (model )
306
-
307
311
model .save_config (save_dir_path )
308
- model .tokenizer .save_pretrained (save_dir_path .joinpath ("tokenizer" ))
309
- model .scheduler .save_pretrained (save_dir_path .joinpath ("scheduler" ))
310
- if model .feature_extractor is not None :
311
- model .feature_extractor .save_pretrained (save_dir_path .joinpath ("feature_extractor" ))
312
-
313
312
export_models (
314
313
models_and_onnx_configs = models_and_onnx_configs ,
315
314
output_dir = save_dir_path ,
@@ -325,7 +324,10 @@ def _from_transformers(
325
324
force_download = force_download ,
326
325
cache_dir = cache_dir ,
327
326
local_files_only = local_files_only ,
328
- model_save_dir = save_dir , # important
327
+ model_save_dir = save_dir ,
328
+ tokenizer = tokenizer or model .tokenizer ,
329
+ scheduler = scheduler or model .scheduler ,
330
+ feature_extractor = feature_extractor or model .feature_extractor ,
329
331
** kwargs ,
330
332
)
331
333
0 commit comments