@@ -1671,16 +1671,17 @@ def _infer_library_from_model(
1671
1671
if library_name is not None :
1672
1672
return library_name
1673
1673
1674
- if (
1674
+ # SentenceTransformer models have no config attributes
1675
+ if hasattr (model , "_model_config" ):
1676
+ library_name = "sentence_transformers"
1677
+ elif (
1675
1678
hasattr (model , "pretrained_cfg" )
1676
1679
or hasattr (model .config , "pretrained_cfg" )
1677
1680
or hasattr (model .config , "architecture" )
1678
1681
):
1679
1682
library_name = "timm"
1680
1683
elif hasattr (model .config , "_diffusers_version" ) or getattr (model , "config_name" , "" ) == "model_index.json" :
1681
1684
library_name = "diffusers"
1682
- elif hasattr (model , "_model_config" ):
1683
- library_name = "sentence_transformers"
1684
1685
else :
1685
1686
library_name = "transformers"
1686
1687
return library_name
@@ -1905,7 +1906,6 @@ def get_model_from_task(
1905
1906
model_class = TasksManager .get_model_class_for_task (
1906
1907
task , framework , model_type = model_type , model_class_name = model_class_name , library = library_name
1907
1908
)
1908
-
1909
1909
if library_name == "timm" :
1910
1910
model = model_class (f"hf_hub:{ model_name_or_path } " , pretrained = True , exportable = True )
1911
1911
model = model .to (torch_dtype ).to (device )
0 commit comments