Skip to content

Commit 5584eb8

Browse files
authored
Fix infer task for stable diffusion (#1793)
* fix * apply suggestions
1 parent 253c6c2 commit 5584eb8

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

optimum/exporters/tasks.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1560,7 +1560,14 @@ def _infer_task_from_model_name_or_path(
15601560
library_name = TasksManager.infer_library_from_model(model_name_or_path, subfolder, revision)
15611561

15621562
if library_name == "diffusers":
1563-
class_name = model_info.config["diffusers"]["class_name"]
1563+
if model_info.config["diffusers"].get("class_name", None):
1564+
class_name = model_info.config["diffusers"]["class_name"]
1565+
elif model_info.config["diffusers"].get("_class_name", None):
1566+
class_name = model_info.config["diffusers"]["_class_name"]
1567+
else:
1568+
raise ValueError(
1569+
f"Could not automatically infer the class name for {model_name_or_path}. Please open an issue at https://github.com/huggingface/optimum/issues."
1570+
)
15641571
inferred_task_name = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion"
15651572
elif library_name == "timm":
15661573
inferred_task_name = "image-classification"

0 commit comments

Comments
 (0)