@@ -240,6 +240,9 @@ def main_export(
     loading_kwargs = model_loading_kwargs or {}
     if variant is not None:
         loading_kwargs["variant"] = variant
+    dtype = loading_kwargs.get("torch_dtype", None)
+    if isinstance(dtype, str):
+        dtype = getattr(torch, dtype) if dtype != "auto" else dtype
     if library_name == "transformers":
         config = AutoConfig.from_pretrained(
             model_name_or_path,
@@ -302,9 +305,8 @@ def main_export(
                     "Please provide custom export config if you want load model with remote code."
                 )
                 trust_remote_code = False
-        dtype = loading_kwargs.get("torch_dtype")
-        if isinstance(dtype, str):
-            dtype = getattr(config, "torch_dtype") if dtype == "auto" else getattr(torch, dtype)
+        if dtype == "auto":
+            dtype = getattr(config, "torch_dtype")
 
         if (
             dtype is None
@@ -351,19 +353,28 @@ class StoreAttr(object):
         GPTQQuantizer.post_init_model = post_init_model
     elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
         _loading_kwargs = {} if variant is None else {"variant": variant}
-        dtype = deduce_diffusers_dtype(
-            model_name_or_path,
-            revision=revision,
-            cache_dir=cache_dir,
-            token=token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **_loading_kwargs,
-        )
+        if dtype == "auto" or dtype is None:
+            dtype = deduce_diffusers_dtype(
+                model_name_or_path,
+                revision=revision,
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                force_download=force_download,
+                trust_remote_code=trust_remote_code,
+                **_loading_kwargs,
+            )
+            if (
+                dtype in {torch.bfloat16, torch.float16}
+                and ov_config is not None
+                and ov_config.dtype in {"fp16", "fp32"}
+            ):
+                dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
         if dtype in [torch.float16, torch.bfloat16]:
             loading_kwargs["torch_dtype"] = dtype
             patch_16bit = True
+    if loading_kwargs.get("torch_dtype") == "auto":
+        loading_kwargs["torch_dtype"] = dtype
 
     try:
         if library_name == "open_clip":