Skip to content

Commit 7d5e3c6

Browse files
committed
rework logic for dtype handling
1 parent 51b4e0c commit 7d5e3c6

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

optimum/exporters/openvino/__main__.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ def main_export(
240240
loading_kwargs = model_loading_kwargs or {}
241241
if variant is not None:
242242
loading_kwargs["variant"] = variant
243+
dtype = loading_kwargs.get("torch_dtype", None)
244+
if isinstance(dtype, str):
245+
dtype = getattr(torch, dtype) if dtype != "auto" else dtype
243246
if library_name == "transformers":
244247
config = AutoConfig.from_pretrained(
245248
model_name_or_path,
@@ -302,9 +305,8 @@ def main_export(
302305
"Please provide custom export config if you want load model with remote code."
303306
)
304307
trust_remote_code = False
305-
dtype = loading_kwargs.get("torch_dtype")
306-
if isinstance(dtype, str):
307-
dtype = getattr(config, "torch_dtype") if dtype == "auto" else getattr(torch, dtype)
308+
if dtype == "auto":
309+
dtype = getattr(config, "torch_dtype")
308310

309311
if (
310312
dtype is None
@@ -351,12 +353,7 @@ class StoreAttr(object):
351353
GPTQQuantizer.post_init_model = post_init_model
352354
elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
353355
_loading_kwargs = {} if variant is None else {"variant": variant}
354-
dtype = loading_kwargs.pop("torch_dtype", None)
355-
if isinstance(dtype, str):
356-
dtype = None if dtype == "auto" else getattr(torch, dtype)
357-
if ov_config is not None and ov_config.dtype in {"fp16", "fp32"}:
358-
dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
359-
if dtype is None:
356+
if dtype == "auto" or dtype is None:
360357
dtype = deduce_diffusers_dtype(
361358
model_name_or_path,
362359
revision=revision,
@@ -367,9 +364,17 @@ class StoreAttr(object):
367364
trust_remote_code=trust_remote_code,
368365
**_loading_kwargs,
369366
)
367+
if (
368+
dtype in {torch.bfloat16, torch.float16}
369+
and ov_config is not None
370+
and ov_config.dtype in {"fp16", "fp32"}
371+
):
372+
dtype = torch.float16 if ov_config.dtype == "fp16" else torch.float32
370373
if dtype in [torch.float16, torch.bfloat16]:
371374
loading_kwargs["torch_dtype"] = dtype
372375
patch_16bit = True
376+
if loading_kwargs.get("torch_dtype") == "auto":
377+
loading_kwargs["torch_dtype"] = dtype
373378

374379
try:
375380
if library_name == "open_clip":

0 commit comments

Comments
 (0)