Skip to content

Commit c55f882

Browse files
authored
Fix infer library for sentence transformers models (#1832)
1 parent 3b5c486 commit c55f882

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

optimum/exporters/tasks.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1671,16 +1671,17 @@ def _infer_library_from_model(
16711671
if library_name is not None:
16721672
return library_name
16731673

1674-
if (
1674+
# SentenceTransformer models have no config attributes
1675+
if hasattr(model, "_model_config"):
1676+
library_name = "sentence_transformers"
1677+
elif (
16751678
hasattr(model, "pretrained_cfg")
16761679
or hasattr(model.config, "pretrained_cfg")
16771680
or hasattr(model.config, "architecture")
16781681
):
16791682
library_name = "timm"
16801683
elif hasattr(model.config, "_diffusers_version") or getattr(model, "config_name", "") == "model_index.json":
16811684
library_name = "diffusers"
1682-
elif hasattr(model, "_model_config"):
1683-
library_name = "sentence_transformers"
16841685
else:
16851686
library_name = "transformers"
16861687
return library_name
@@ -1905,7 +1906,6 @@ def get_model_from_task(
19051906
model_class = TasksManager.get_model_class_for_task(
19061907
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
19071908
)
1908-
19091909
if library_name == "timm":
19101910
model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True)
19111911
model = model.to(torch_dtype).to(device)

0 commit comments

Comments
 (0)