Commit 51b4e0c 1 parent 6cceb30 commit 51b4e0c Copy full SHA for 51b4e0c
File tree 2 files changed +25
-11
lines changed
2 files changed +25
-11
lines changed Original file line number Diff line number Diff line change @@ -351,16 +351,22 @@ class StoreAttr(object):
351
351
GPTQQuantizer .post_init_model = post_init_model
352
352
elif library_name == "diffusers" and is_openvino_version (">=" , "2024.6" ):
353
353
_loading_kwargs = {} if variant is None else {"variant" : variant }
354
- dtype = deduce_diffusers_dtype (
355
- model_name_or_path ,
356
- revision = revision ,
357
- cache_dir = cache_dir ,
358
- token = token ,
359
- local_files_only = local_files_only ,
360
- force_download = force_download ,
361
- trust_remote_code = trust_remote_code ,
362
- ** _loading_kwargs ,
363
- )
354
+ dtype = loading_kwargs .pop ("torch_dtype" , None )
355
+ if isinstance (dtype , str ):
356
+ dtype = None if dtype == "auto" else getattr (torch , dtype )
357
+ if ov_config is not None and ov_config .dtype in {"fp16" , "fp32" }:
358
+ dtype = torch .float16 if ov_config .dtype == "fp16" else torch .float32
359
+ if dtype is None :
360
+ dtype = deduce_diffusers_dtype (
361
+ model_name_or_path ,
362
+ revision = revision ,
363
+ cache_dir = cache_dir ,
364
+ token = token ,
365
+ local_files_only = local_files_only ,
366
+ force_download = force_download ,
367
+ trust_remote_code = trust_remote_code ,
368
+ ** _loading_kwargs ,
369
+ )
364
370
if dtype in [torch .float16 , torch .bfloat16 ]:
365
371
loading_kwargs ["torch_dtype" ] = dtype
366
372
patch_16bit = True
Original file line number Diff line number Diff line change @@ -593,7 +593,14 @@ def _from_transformers(
593
593
if load_in_8bit is None and not quantization_config :
594
594
ov_config = None
595
595
else :
596
- ov_config = OVConfig (dtype = "fp32" )
596
+ ov_config = OVConfig (dtype = "auto" )
597
+
598
+ torch_dtype = kwargs .pop ("torch_dtype" , None )
599
+
600
+ model_loading_kwargs = {}
601
+
602
+ if torch_dtype is not None :
603
+ model_loading_kwargs ["torch_dtype" ] = torch_dtype
597
604
598
605
model_save_dir = TemporaryDirectory ()
599
606
model_save_path = Path (model_save_dir .name )
@@ -613,6 +620,7 @@ def _from_transformers(
613
620
ov_config = ov_config ,
614
621
library_name = cls ._library_name ,
615
622
variant = variant ,
623
+ model_loading_kwargs = model_loading_kwargs ,
616
624
)
617
625
618
626
return cls ._from_pretrained (
You can’t perform that action at this time.
0 commit comments