Skip to content

Commit 9201bd4

Browse files
authored on Feb 8, 2024
Update standardize_model_attributes (#1686)
* update files
* update library getter
* fixed typo
* fix library name detection
* fix library name
* update todo
1 parent 32a51af commit 9201bd4

File tree

5 files changed

+47
-96
lines changed

5 files changed

+47
-96
lines changed
 

‎optimum/exporters/onnx/__main__.py

-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ def main_export(
211211
task = TasksManager.map_from_synonym(task)
212212

213213
framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
214-
215214
library_name = TasksManager.infer_library_from_model(
216215
model_name_or_path, subfolder=subfolder, library_name=library_name
217216
)

‎optimum/exporters/onnx/convert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ def onnx_export_from_model(
988988
"""
989989
library_name = TasksManager._infer_library_from_model(model)
990990

991-
# TODO: call standardize_model_attributes here once its model_name_or_path argument is optional.
991+
TasksManager.standardize_model_attributes(model, library_name)
992992

993993
if hasattr(model.config, "export_model_type"):
994994
model_type = model.config.export_model_type.replace("_", "-")

‎optimum/exporters/onnx/model_configs.py

-17
Original file line numberDiff line numberDiff line change
@@ -781,23 +781,6 @@ class TimmDefaultOnnxConfig(ViTOnnxConfig):
781781
ATOL_FOR_VALIDATION = 1e-3
782782
DEFAULT_ONNX_OPSET = 12
783783

784-
def __init__(
785-
self,
786-
config: "PretrainedConfig",
787-
task: str = "feature-extraction",
788-
preprocessors: Optional[List[Any]] = None,
789-
int_dtype: str = "int64",
790-
float_dtype: str = "fp32",
791-
legacy: bool = False,
792-
):
793-
super().__init__(config, task, preprocessors, int_dtype, float_dtype, legacy)
794-
795-
pretrained_cfg = self._config
796-
if hasattr(self._config, "pretrained_cfg"):
797-
pretrained_cfg = self._config.pretrained_cfg
798-
799-
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(pretrained_cfg)
800-
801784
def rename_ambiguous_inputs(self, inputs):
802785
# The input name in the model signature is `x, hence the export input name is updated.
803786
model_inputs = {}

‎optimum/exporters/tasks.py

+44-75
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import importlib
1818
import inspect
1919
import itertools
20-
import json
2120
import os
2221
from functools import partial
2322
from pathlib import Path
@@ -1523,10 +1522,12 @@ def _infer_task_from_model_name_or_path(
15231522
"Cannot infer the task from a model repo with a subfolder yet, please specify the task manually."
15241523
)
15251524
model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
1526-
if getattr(model_info, "library_name", None) == "diffusers":
1525+
library_name = TasksManager.infer_library_from_model(model_name_or_path, subfolder, revision)
1526+
1527+
if library_name == "diffusers":
15271528
class_name = model_info.config["diffusers"]["class_name"]
15281529
inferred_task_name = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion"
1529-
elif getattr(model_info, "library_name", None) == "timm":
1530+
elif library_name == "timm":
15301531
inferred_task_name = "image-classification"
15311532
else:
15321533
pipeline_tag = getattr(model_info, "pipeline_tag", None)
@@ -1544,13 +1545,9 @@ def _infer_task_from_model_name_or_path(
15441545
# transformersInfo does not always have a pipeline_tag attribute
15451546
class_name_prefix = ""
15461547
if is_torch_available():
1547-
tasks_to_automodels = TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP[
1548-
model_info.library_name
1549-
]
1548+
tasks_to_automodels = TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP[library_name]
15501549
else:
1551-
tasks_to_automodels = TasksManager._LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP[
1552-
model_info.library_name
1553-
]
1550+
tasks_to_automodels = TasksManager._LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP[library_name]
15541551
class_name_prefix = "TF"
15551552

15561553
auto_model_class_name = transformers_info["auto_model"]
@@ -1603,8 +1600,17 @@ def infer_task_from_model(
16031600
return task
16041601

16051602
@staticmethod
1606-
def _infer_library_from_model(model: Union["PreTrainedModel", "TFPreTrainedModel"]):
1607-
if hasattr(model.config, "pretrained_cfg") or hasattr(model.config, "architecture"):
1603+
def _infer_library_from_model(
1604+
model: Union["PreTrainedModel", "TFPreTrainedModel"], library_name: Optional[str] = None
1605+
):
1606+
if library_name is not None:
1607+
return library_name
1608+
1609+
if (
1610+
hasattr(model, "pretrained_cfg")
1611+
or hasattr(model.config, "pretrained_cfg")
1612+
or hasattr(model.config, "architecture")
1613+
):
16081614
library_name = "timm"
16091615
elif hasattr(model.config, "_diffusers_version") or getattr(model, "config_name", "") == "model_index.json":
16101616
library_name = "diffusers"
@@ -1645,44 +1651,33 @@ def infer_library_from_model(
16451651
if library_name is not None:
16461652
return library_name
16471653

1648-
full_model_path = Path(model_name_or_path) / subfolder
1649-
1650-
if not full_model_path.is_dir():
1651-
model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
1652-
library_name = getattr(model_info, "library_name", None)
1653-
1654-
# sentence-transformers package name is sentence_transformers
1655-
if library_name is not None:
1656-
library_name = library_name.replace("-", "_")
1654+
all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
16571655

1658-
if library_name is None:
1659-
all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
1656+
if "model_index.json" in all_files:
1657+
library_name = "diffusers"
1658+
elif CONFIG_NAME in all_files:
1659+
# We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type.
1660+
kwargs = {
1661+
"subfolder": subfolder,
1662+
"revision": revision,
1663+
"cache_dir": cache_dir,
1664+
}
1665+
config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs)
1666+
model_config = PretrainedConfig.from_dict(config_dict, **kwargs)
16601667

1661-
if "model_index.json" in all_files:
1668+
if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"):
1669+
library_name = "timm"
1670+
elif hasattr(model_config, "_diffusers_version"):
16621671
library_name = "diffusers"
1663-
elif CONFIG_NAME in all_files:
1664-
# We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type.
1665-
kwargs = {
1666-
"subfolder": subfolder,
1667-
"revision": revision,
1668-
"cache_dir": cache_dir,
1669-
}
1670-
config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs)
1671-
model_config = PretrainedConfig.from_dict(config_dict, **kwargs)
1672-
1673-
if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"):
1674-
library_name = "timm"
1675-
elif hasattr(model_config, "_diffusers_version"):
1676-
library_name = "diffusers"
1677-
else:
1678-
library_name = "transformers"
1679-
elif (
1680-
any(file_path.startswith("sentence_") for file_path in all_files)
1681-
or "config_sentence_transformers.json" in all_files
1682-
):
1683-
library_name = "sentence_transformers"
16841672
else:
16851673
library_name = "transformers"
1674+
elif (
1675+
any(file_path.startswith("sentence_") for file_path in all_files)
1676+
or "config_sentence_transformers.json" in all_files
1677+
):
1678+
library_name = "sentence_transformers"
1679+
else:
1680+
library_name = "transformers"
16861681

16871682
if library_name is None:
16881683
raise ValueError(
@@ -1694,11 +1689,7 @@ def infer_library_from_model(
16941689
@classmethod
16951690
def standardize_model_attributes(
16961691
cls,
1697-
model_name_or_path: Union[str, Path],
16981692
model: Union["PreTrainedModel", "TFPreTrainedModel"],
1699-
subfolder: str = "",
1700-
revision: Optional[str] = None,
1701-
cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
17021693
library_name: Optional[str] = None,
17031694
):
17041695
"""
@@ -1721,40 +1712,20 @@ def standardize_model_attributes(
17211712
library_name (`Optional[str]`, *optional*)::
17221713
The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers".
17231714
"""
1724-
# TODO: make model_name_or_path an optional argument here.
1725-
1726-
library_name = TasksManager.infer_library_from_model(
1727-
model_name_or_path, subfolder, revision, cache_dir, library_name
1728-
)
1729-
1730-
full_model_path = Path(model_name_or_path) / subfolder
1731-
is_local = full_model_path.is_dir()
1715+
library_name = TasksManager._infer_library_from_model(model, library_name)
17321716

17331717
if library_name == "diffusers":
17341718
model.config.export_model_type = "stable-diffusion"
17351719
elif library_name == "timm":
17361720
# Retrieve model config
1737-
config_path = full_model_path / "config.json"
1738-
1739-
if not is_local:
1740-
config_path = huggingface_hub.hf_hub_download(
1741-
model_name_or_path, "config.json", subfolder=subfolder, revision=revision
1742-
)
1743-
1744-
model_config = PretrainedConfig.from_json_file(config_path)
1745-
1746-
if hasattr(model_config, "pretrained_cfg"):
1747-
model_config.pretrained_cfg = PretrainedConfig.from_dict(model_config.pretrained_cfg)
1721+
model_config = PretrainedConfig.from_dict(model.pretrained_cfg)
17481722

17491723
# Set config as in transformers
17501724
setattr(model, "config", model_config)
17511725

1752-
# Update model_type for model
1753-
with open(config_path) as fp:
1754-
model_type = json.load(fp)["architecture"]
1755-
17561726
# `model_type` is a class attribute in Transformers, let's avoid modifying it.
1757-
model.config.export_model_type = model_type
1727+
model.config.export_model_type = model.pretrained_cfg["architecture"]
1728+
17581729
elif library_name == "sentence_transformers":
17591730
if "Transformer" in model[0].__class__.__name__:
17601731
model.config = model[0].auto_model.config
@@ -1903,9 +1874,7 @@ def get_model_from_task(
19031874
kwargs["from_pt"] = True
19041875
model = model_class.from_pretrained(model_name_or_path, **kwargs)
19051876

1906-
TasksManager.standardize_model_attributes(
1907-
model_name_or_path, model, subfolder, revision, cache_dir, library_name
1908-
)
1877+
TasksManager.standardize_model_attributes(model, library_name)
19091878

19101879
return model
19111880

‎tests/exporters/onnx/test_onnx_export.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _onnx_export(
185185
if library_name == "timm":
186186
model_class = TasksManager.get_model_class_for_task(task, library=library_name)
187187
model = model_class(f"hf_hub:{model_name}", pretrained=True, exportable=True)
188-
TasksManager.standardize_model_attributes(model_name, model, library_name=library_name)
188+
TasksManager.standardize_model_attributes(model, library_name=library_name)
189189
else:
190190
config = AutoConfig.from_pretrained(model_name)
191191
model_class = TasksManager.get_model_class_for_task(task, model_type=config.model_type.replace("_", "-"))
@@ -611,7 +611,7 @@ def _onnx_export(
611611
if library_name == "timm":
612612
model_class = TasksManager.get_model_class_for_task(task, library=library_name)
613613
model = model_class(f"hf_hub:{model_name}", pretrained=True, exportable=True)
614-
TasksManager.standardize_model_attributes(model_name, model, library_name=library_name)
614+
TasksManager.standardize_model_attributes(model, library_name=library_name)
615615
else:
616616
config = AutoConfig.from_pretrained(model_name)
617617
model_class = TasksManager.get_model_class_for_task(task, model_type=config.model_type.replace("_", "-"))

0 commit comments

Comments (0)