@@ -240,6 +240,9 @@ def main_export(
240
240
loading_kwargs = model_loading_kwargs or {}
241
241
if variant is not None :
242
242
loading_kwargs ["variant" ] = variant
243
+ dtype = loading_kwargs .get ("torch_dtype" , None )
244
+ if isinstance (dtype , str ):
245
+ dtype = getattr (torch , dtype ) if dtype != "auto" else dtype
243
246
if library_name == "transformers" :
244
247
config = AutoConfig .from_pretrained (
245
248
model_name_or_path ,
@@ -302,9 +305,8 @@ def main_export(
302
305
"Please provide custom export config if you want load model with remote code."
303
306
)
304
307
trust_remote_code = False
305
- dtype = loading_kwargs .get ("torch_dtype" )
306
- if isinstance (dtype , str ):
307
- dtype = getattr (config , "torch_dtype" ) if dtype == "auto" else getattr (torch , dtype )
308
+ if dtype == "auto" :
309
+ dtype = getattr (config , "torch_dtype" )
308
310
309
311
if (
310
312
dtype is None
@@ -351,12 +353,7 @@ class StoreAttr(object):
351
353
GPTQQuantizer .post_init_model = post_init_model
352
354
elif library_name == "diffusers" and is_openvino_version (">=" , "2024.6" ):
353
355
_loading_kwargs = {} if variant is None else {"variant" : variant }
354
- dtype = loading_kwargs .pop ("torch_dtype" , None )
355
- if isinstance (dtype , str ):
356
- dtype = None if dtype == "auto" else getattr (torch , dtype )
357
- if ov_config is not None and ov_config .dtype in {"fp16" , "fp32" }:
358
- dtype = torch .float16 if ov_config .dtype == "fp16" else torch .float32
359
- if dtype is None :
356
+ if dtype == "auto" or dtype is None :
360
357
dtype = deduce_diffusers_dtype (
361
358
model_name_or_path ,
362
359
revision = revision ,
@@ -367,9 +364,17 @@ class StoreAttr(object):
367
364
trust_remote_code = trust_remote_code ,
368
365
** _loading_kwargs ,
369
366
)
367
+ if (
368
+ dtype in {torch .bfloat16 , torch .float16 }
369
+ and ov_config is not None
370
+ and ov_config .dtype in {"fp16" , "fp32" }
371
+ ):
372
+ dtype = torch .float16 if ov_config .dtype == "fp16" else torch .float32
370
373
if dtype in [torch .float16 , torch .bfloat16 ]:
371
374
loading_kwargs ["torch_dtype" ] = dtype
372
375
patch_16bit = True
376
+ if loading_kwargs .get ("torch_dtype" ) == "auto" :
377
+ loading_kwargs ["torch_dtype" ] = dtype
373
378
374
379
try :
375
380
if library_name == "open_clip" :
0 commit comments