Skip to content

Commit ba9aaaf

Browse files
authored
Align loading dtype logic for diffusers with other models (huggingface#1187)
* align loading dtype logic for diffusers with other models
* rework logic for dtype handling
1 parent 6cceb30 commit ba9aaaf

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

optimum/exporters/openvino/__main__.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ def main_export(
240240
loading_kwargs = model_loading_kwargs or {}
241241
if variant is not None:
242242
loading_kwargs["variant"] = variant
243+
dtype = loading_kwargs.get("torch_dtype", None)
244+
if isinstance(dtype, str):
245+
dtype = getattr(torch, dtype) if dtype != "auto" else dtype
243246
if library_name == "transformers":
244247
config = AutoConfig.from_pretrained(
245248
model_name_or_path,
@@ -302,9 +305,8 @@ def main_export(
302305
"Please provide custom export config if you want load model with remote code."
303306
)
304307
trust_remote_code = False
305-
dtype = loading_kwargs.get("torch_dtype")
306-
if isinstance(dtype, str):
307-
dtype = getattr(config, "torch_dtype") if dtype == "auto" else getattr(torch, dtype)
308+
if dtype == "auto":
309+
dtype = getattr(config, "torch_dtype")
308310

309311
if (
310312
dtype is None
@@ -351,19 +353,28 @@ class StoreAttr(object):
351353
GPTQQuantizer.post_init_model = post_init_model
352354
elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
353355
_loading_kwargs = {} if variant is None else {"variant": variant}
354-
dtype = deduce_diffusers_dtype(
355-
model_name_or_path,
356-
revision=revision,
357-
cache_dir=cache_dir,
358-
token=token,
359-
local_files_only=local_files_only,
360-
force_download=force_download,
361-
trust_remote_code=trust_remote_code,
362-
**_loading_kwargs,
363-
)
356+
if dtype == "auto" or dtype is None:
357+
dtype = deduce_diffusers_dtype(
358+
model_name_or_path,
359+
revision=revision,
360+
cache_dir=cache_dir,
361+
token=token,
362+
local_files_only=local_files_only,
363+
force_download=force_download,
364+
trust_remote_code=trust_remote_code,
365+
**_loading_kwargs,
366+
)
367+
if (
368+
dtype in {torch.bfloat16, torch.float16}
369+
and ov_config is not None
370+
and ov_config.dtype in {"fp16", "fp32"}
371+
):
372+
dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
364373
if dtype in [torch.float16, torch.bfloat16]:
365374
loading_kwargs["torch_dtype"] = dtype
366375
patch_16bit = True
376+
if loading_kwargs.get("torch_dtype") == "auto":
377+
loading_kwargs["torch_dtype"] = dtype
367378

368379
try:
369380
if library_name == "open_clip":

optimum/intel/openvino/modeling_diffusion.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,14 @@ def _from_transformers(
593593
if load_in_8bit is None and not quantization_config:
594594
ov_config = None
595595
else:
596-
ov_config = OVConfig(dtype="fp32")
596+
ov_config = OVConfig(dtype="auto")
597+
598+
torch_dtype = kwargs.pop("torch_dtype", None)
599+
600+
model_loading_kwargs = {}
601+
602+
if torch_dtype is not None:
603+
model_loading_kwargs["torch_dtype"] = torch_dtype
597604

598605
model_save_dir = TemporaryDirectory()
599606
model_save_path = Path(model_save_dir.name)
@@ -613,6 +620,7 @@ def _from_transformers(
613620
ov_config=ov_config,
614621
library_name=cls._library_name,
615622
variant=variant,
623+
model_loading_kwargs=model_loading_kwargs,
616624
)
617625

618626
return cls._from_pretrained(

0 commit comments

Comments (0)