 import importlib
 import inspect
 import itertools
-import json
 import os
 from functools import partial
 from pathlib import Path
@@ -1523,10 +1522,12 @@ def _infer_task_from_model_name_or_path(
                     "Cannot infer the task from a model repo with a subfolder yet, please specify the task manually."
                 )
             model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
-            if getattr(model_info, "library_name", None) == "diffusers":
+            library_name = TasksManager.infer_library_from_model(model_name_or_path, subfolder, revision)
+
+            if library_name == "diffusers":
                 class_name = model_info.config["diffusers"]["class_name"]
                 inferred_task_name = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion"
-            elif getattr(model_info, "library_name", None) == "timm":
+            elif library_name == "timm":
                 inferred_task_name = "image-classification"
             else:
                 pipeline_tag = getattr(model_info, "pipeline_tag", None)
@@ -1544,13 +1545,9 @@ def _infer_task_from_model_name_or_path(
                         # transformersInfo does not always have a pipeline_tag attribute
                         class_name_prefix = ""
                         if is_torch_available():
-                            tasks_to_automodels = TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP[
-                                model_info.library_name
-                            ]
+                            tasks_to_automodels = TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP[library_name]
                         else:
-                            tasks_to_automodels = TasksManager._LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP[
-                                model_info.library_name
-                            ]
+                            tasks_to_automodels = TasksManager._LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP[library_name]
                             class_name_prefix = "TF"
 
                         auto_model_class_name = transformers_info["auto_model"]
@@ -1603,8 +1600,17 @@ def infer_task_from_model(
         return task
 
     @staticmethod
-    def _infer_library_from_model(model: Union["PreTrainedModel", "TFPreTrainedModel"]):
-        if hasattr(model.config, "pretrained_cfg") or hasattr(model.config, "architecture"):
+    def _infer_library_from_model(
+        model: Union["PreTrainedModel", "TFPreTrainedModel"], library_name: Optional[str] = None
+    ):
+        if library_name is not None:
+            return library_name
+
+        if (
+            hasattr(model, "pretrained_cfg")
+            or hasattr(model.config, "pretrained_cfg")
+            or hasattr(model.config, "architecture")
+        ):
             library_name = "timm"
         elif hasattr(model.config, "_diffusers_version") or getattr(model, "config_name", "") == "model_index.json":
             library_name = "diffusers"
@@ -1645,44 +1651,33 @@ def infer_library_from_model(
         if library_name is not None:
             return library_name
 
-        full_model_path = Path(model_name_or_path) / subfolder
-
-        if not full_model_path.is_dir():
-            model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
-            library_name = getattr(model_info, "library_name", None)
-
-            # sentence-transformers package name is sentence_transformers
-            if library_name is not None:
-                library_name = library_name.replace("-", "_")
+        all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
 
-        if library_name is None:
-            all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
+        if "model_index.json" in all_files:
+            library_name = "diffusers"
+        elif CONFIG_NAME in all_files:
+            # We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type.
+            kwargs = {
+                "subfolder": subfolder,
+                "revision": revision,
+                "cache_dir": cache_dir,
+            }
+            config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs)
+            model_config = PretrainedConfig.from_dict(config_dict, **kwargs)
 
-            if "model_index.json" in all_files:
+            if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"):
+                library_name = "timm"
+            elif hasattr(model_config, "_diffusers_version"):
                 library_name = "diffusers"
-            elif CONFIG_NAME in all_files:
-                # We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type.
-                kwargs = {
-                    "subfolder": subfolder,
-                    "revision": revision,
-                    "cache_dir": cache_dir,
-                }
-                config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs)
-                model_config = PretrainedConfig.from_dict(config_dict, **kwargs)
-
-                if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"):
-                    library_name = "timm"
-                elif hasattr(model_config, "_diffusers_version"):
-                    library_name = "diffusers"
-                else:
-                    library_name = "transformers"
-            elif (
-                any(file_path.startswith("sentence_") for file_path in all_files)
-                or "config_sentence_transformers.json" in all_files
-            ):
-                library_name = "sentence_transformers"
             else:
                 library_name = "transformers"
+        elif (
+            any(file_path.startswith("sentence_") for file_path in all_files)
+            or "config_sentence_transformers.json" in all_files
+        ):
+            library_name = "sentence_transformers"
+        else:
+            library_name = "transformers"
 
         if library_name is None:
             raise ValueError(
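
For reference, a minimal sketch of how the rewritten file-based detection above is meant to behave. The repo IDs below are illustrative assumptions rather than part of this change, and the calls list repo files from the Hugging Face Hub:

```python
from optimum.exporters.tasks import TasksManager

# Detection order introduced above: "model_index.json" -> diffusers; a config.json carrying
# "pretrained_cfg"/"architecture" -> timm; sentence-transformers marker files -> sentence_transformers;
# otherwise a plain config.json -> transformers.
print(TasksManager.infer_library_from_model("bert-base-uncased"))                   # expected: "transformers"
print(TasksManager.infer_library_from_model("timm/mobilenetv3_large_100.ra_in1k"))  # expected: "timm"
print(TasksManager.infer_library_from_model("stabilityai/stable-diffusion-2-1"))    # expected: "diffusers"
```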
@@ -1694,11 +1689,7 @@ def infer_library_from_model(
     @classmethod
     def standardize_model_attributes(
         cls,
-        model_name_or_path: Union[str, Path],
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
-        subfolder: str = "",
-        revision: Optional[str] = None,
-        cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
         library_name: Optional[str] = None,
     ):
         """
@@ -1721,40 +1712,20 @@ def standardize_model_attributes(
             library_name (`Optional[str]`, *optional*)::
                 The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers".
         """
-        # TODO: make model_name_or_path an optional argument here.
-
-        library_name = TasksManager.infer_library_from_model(
-            model_name_or_path, subfolder, revision, cache_dir, library_name
-        )
-
-        full_model_path = Path(model_name_or_path) / subfolder
-        is_local = full_model_path.is_dir()
+        library_name = TasksManager._infer_library_from_model(model, library_name)
 
         if library_name == "diffusers":
             model.config.export_model_type = "stable-diffusion"
         elif library_name == "timm":
             # Retrieve model config
-            config_path = full_model_path / "config.json"
-
-            if not is_local:
-                config_path = huggingface_hub.hf_hub_download(
-                    model_name_or_path, "config.json", subfolder=subfolder, revision=revision
-                )
-
-            model_config = PretrainedConfig.from_json_file(config_path)
-
-            if hasattr(model_config, "pretrained_cfg"):
-                model_config.pretrained_cfg = PretrainedConfig.from_dict(model_config.pretrained_cfg)
+            model_config = PretrainedConfig.from_dict(model.pretrained_cfg)
 
             # Set config as in transformers
             setattr(model, "config", model_config)
 
-            # Update model_type for model
-            with open(config_path) as fp:
-                model_type = json.load(fp)["architecture"]
-
             # `model_type` is a class attribute in Transformers, let's avoid modifying it.
-            model.config.export_model_type = model_type
+            model.config.export_model_type = model.pretrained_cfg["architecture"]
+
         elif library_name == "sentence_transformers":
             if "Transformer" in model[0].__class__.__name__:
                 model.config = model[0].auto_model.config
@@ -1903,9 +1874,7 @@ def get_model_from_task(
                 kwargs["from_pt"] = True
                 model = model_class.from_pretrained(model_name_or_path, **kwargs)
 
-        TasksManager.standardize_model_attributes(
-            model_name_or_path, model, subfolder, revision, cache_dir, library_name
-        )
+        TasksManager.standardize_model_attributes(model, library_name)
 
         return model
 
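
Taken together, the call-site contract becomes simpler: `standardize_model_attributes` now works purely off the in-memory model, so callers no longer need to thread `model_name_or_path`, `subfolder`, `revision`, or `cache_dir` through. A hedged sketch for a timm model, assuming timm is installed (model name and printed value are illustrative):

```python
import timm

from optimum.exporters.tasks import TasksManager

# Load a timm model directly; recent timm versions attach a `pretrained_cfg` dict to the model,
# which is all the refactored standardize_model_attributes() reads for timm.
model = timm.create_model("mobilenetv3_large_100", pretrained=False)

# Previously the caller passed repo information:
#   TasksManager.standardize_model_attributes(model_name_or_path, model, subfolder, revision, cache_dir, library_name)
TasksManager.standardize_model_attributes(model, library_name="timm")

print(model.config.export_model_type)  # expected: the timm architecture name, e.g. "mobilenetv3_large_100"
```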