Skip to content

Commit ec85aa9

Browse files
authored
Fix library detection (#1690)
* fix task detection * remove unnecessary workflow
1 parent 96c6d48 commit ec85aa9

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

optimum/exporters/tasks.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1655,6 +1655,11 @@ def infer_library_from_model(
16551655

16561656
if "model_index.json" in all_files:
16571657
library_name = "diffusers"
1658+
elif (
1659+
any(file_path.startswith("sentence_") for file_path in all_files)
1660+
or "config_sentence_transformers.json" in all_files
1661+
):
1662+
library_name = "sentence_transformers"
16581663
elif CONFIG_NAME in all_files:
16591664
# We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type.
16601665
kwargs = {
@@ -1671,11 +1676,6 @@ def infer_library_from_model(
16711676
library_name = "diffusers"
16721677
else:
16731678
library_name = "transformers"
1674-
elif (
1675-
any(file_path.startswith("sentence_") for file_path in all_files)
1676-
or "config_sentence_transformers.json" in all_files
1677-
):
1678-
library_name = "sentence_transformers"
16791679
else:
16801680
library_name = "transformers"
16811681

tests/exporters/common/test_tasks_manager.py

+10
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,13 @@ def test_custom_class(self):
177177

178178
model = TasksManager.get_model_from_task("question-answering", "uclanlp/visualbert-vqa")
179179
self.assertTrue(isinstance(model, VisualBertForQuestionAnswering))
180+
181+
def test_library_detection(self):
182+
self.assertEqual(
183+
TasksManager.infer_library_from_model("intfloat/multilingual-e5-large"), "sentence_transformers"
184+
)
185+
self.assertEqual(
186+
TasksManager.infer_library_from_model("stabilityai/stable-diffusion-xl-base-1.0"), "diffusers"
187+
)
188+
self.assertEqual(TasksManager.infer_library_from_model("gpt2"), "transformers")
189+
self.assertEqual(TasksManager.infer_library_from_model("timm/mobilenetv3_large_100.ra_in1k"), "timm")

0 commit comments

Comments
 (0)