Skip to content

Commit 51b4e0c (parent: 6cceb30) — "align loading dtype logic for diffusers with other models"

File tree

2 files changed: +25 −11 lines

optimum/exporters/openvino/__main__.py  (+16 −10)

@@ -351,16 +351,22 @@ class StoreAttr(object):
             GPTQQuantizer.post_init_model = post_init_model
         elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
             _loading_kwargs = {} if variant is None else {"variant": variant}
-            dtype = deduce_diffusers_dtype(
-                model_name_or_path,
-                revision=revision,
-                cache_dir=cache_dir,
-                token=token,
-                local_files_only=local_files_only,
-                force_download=force_download,
-                trust_remote_code=trust_remote_code,
-                **_loading_kwargs,
-            )
+            dtype = loading_kwargs.pop("torch_dtype", None)
+            if isinstance(dtype, str):
+                dtype = None if dtype == "auto" else getattr(torch, dtype)
+            if ov_config is not None and ov_config.dtype in {"fp16", "fp32"}:
+                dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
+            if dtype is None:
+                dtype = deduce_diffusers_dtype(
+                    model_name_or_path,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    token=token,
+                    local_files_only=local_files_only,
+                    force_download=force_download,
+                    trust_remote_code=trust_remote_code,
+                    **_loading_kwargs,
+                )
             if dtype in [torch.float16, torch.bfloat16]:
                 loading_kwargs["torch_dtype"] = dtype
                 patch_16bit = True

optimum/intel/openvino/modeling_diffusion.py  (+9 −1)

@@ -593,7 +593,14 @@ def _from_transformers(
         if load_in_8bit is None and not quantization_config:
             ov_config = None
         else:
-            ov_config = OVConfig(dtype="fp32")
+            ov_config = OVConfig(dtype="auto")
+
+        torch_dtype = kwargs.pop("torch_dtype", None)
+
+        model_loading_kwargs = {}
+
+        if torch_dtype is not None:
+            model_loading_kwargs["torch_dtype"] = torch_dtype

         model_save_dir = TemporaryDirectory()
         model_save_path = Path(model_save_dir.name)
@@ -613,6 +620,7 @@ def _from_transformers(
             ov_config=ov_config,
             library_name=cls._library_name,
             variant=variant,
+            model_loading_kwargs=model_loading_kwargs,
         )

         return cls._from_pretrained(

0 commit comments