File tree 2 files changed +15
-5
lines changed
2 files changed +15
-5
lines changed Original file line number Diff line number Diff line change @@ -1655,6 +1655,11 @@ def infer_library_from_model(
1655
1655
1656
1656
if "model_index.json" in all_files :
1657
1657
library_name = "diffusers"
1658
+ elif (
1659
+ any (file_path .startswith ("sentence_" ) for file_path in all_files )
1660
+ or "config_sentence_transformers.json" in all_files
1661
+ ):
1662
+ library_name = "sentence_transformers"
1658
1663
elif CONFIG_NAME in all_files :
1659
1664
# We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type.
1660
1665
kwargs = {
@@ -1671,11 +1676,6 @@ def infer_library_from_model(
1671
1676
library_name = "diffusers"
1672
1677
else :
1673
1678
library_name = "transformers"
1674
- elif (
1675
- any (file_path .startswith ("sentence_" ) for file_path in all_files )
1676
- or "config_sentence_transformers.json" in all_files
1677
- ):
1678
- library_name = "sentence_transformers"
1679
1679
else :
1680
1680
library_name = "transformers"
1681
1681
Original file line number Diff line number Diff line change @@ -177,3 +177,13 @@ def test_custom_class(self):
177
177
178
178
model = TasksManager .get_model_from_task ("question-answering" , "uclanlp/visualbert-vqa" )
179
179
self .assertTrue (isinstance (model , VisualBertForQuestionAnswering ))
180
+
181
+ def test_library_detection (self ):
182
+ self .assertEqual (
183
+ TasksManager .infer_library_from_model ("intfloat/multilingual-e5-large" ), "sentence_transformers"
184
+ )
185
+ self .assertEqual (
186
+ TasksManager .infer_library_from_model ("stabilityai/stable-diffusion-xl-base-1.0" ), "diffusers"
187
+ )
188
+ self .assertEqual (TasksManager .infer_library_from_model ("gpt2" ), "transformers" )
189
+ self .assertEqual (TasksManager .infer_library_from_model ("timm/mobilenetv3_large_100.ra_in1k" ), "timm" )
You can’t perform that action at this time.
0 commit comments