Commit adcae38
fix infer task from model_name if model from sentence transformer (#2151)
* fix infer task from model_name if model from sentence transformer
* use library_name for infer task
1 parent: b9fa9aa

File tree

    optimum/exporters/onnx/__main__.py (+1 -1)
    optimum/exporters/tasks.py (+25 -9)
    optimum/exporters/tflite/__main__.py (+7 -2)
    optimum/exporters/tflite/convert.py (+1 -1)

4 files changed: +34 -13 lines changed
optimum/exporters/onnx/__main__.py (+1 -1)

@@ -256,7 +256,7 @@ def main_export(
 
     if task == "auto":
         try:
-            task = TasksManager.infer_task_from_model(model_name_or_path)
+            task = TasksManager.infer_task_from_model(model_name_or_path, library_name=library_name)
         except KeyError as e:
             raise KeyError(
                 f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"

optimum/exporters/tasks.py (+25 -9)

@@ -1782,6 +1782,7 @@ def _infer_task_from_model_name_or_path(
         revision: Optional[str] = None,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
         token: Optional[Union[bool, str]] = None,
+        library_name: Optional[str] = None,
     ) -> str:
         inferred_task_name = None
 
@@ -1803,13 +1804,14 @@ def _infer_task_from_model_name_or_path(
                 raise RuntimeError(
                     f"Hugging Face Hub is not reachable and we cannot infer the task from a cached model. Make sure you are not offline, or otherwise please specify the `task` (or `--task` in command-line) argument ({', '.join(TasksManager.get_all_tasks())})."
                 )
-            library_name = cls.infer_library_from_model(
-                model_name_or_path,
-                subfolder=subfolder,
-                revision=revision,
-                cache_dir=cache_dir,
-                token=token,
-            )
+            if library_name is None:
+                library_name = cls.infer_library_from_model(
+                    model_name_or_path,
+                    subfolder=subfolder,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    token=token,
+                )
 
             if library_name == "timm":
                 inferred_task_name = "image-classification"
@@ -1828,6 +1830,8 @@ def _infer_task_from_model_name_or_path(
                             break
                     if inferred_task_name is not None:
                         break
+            elif library_name == "sentence_transformers":
+                inferred_task_name = "feature-extraction"
             elif library_name == "transformers":
                 pipeline_tag = model_info.pipeline_tag
                 transformers_info = model_info.transformersInfo
@@ -1864,6 +1868,7 @@ def infer_task_from_model(
         revision: Optional[str] = None,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
         token: Optional[Union[bool, str]] = None,
+        library_name: Optional[str] = None,
     ) -> str:
         """
         Infers the task from the model repo, model instance, or model class.
@@ -1882,7 +1887,9 @@ def infer_task_from_model(
             token (`Optional[Union[bool,str]]`, defaults to `None`):
                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                 when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
-
+            library_name (`Optional[str]`, defaults to `None`):
+                The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". See `TasksManager.infer_library_from_model` for the priority should
+                none be provided.
         Returns:
             `str`: The task name automatically detected from the HF hub repo, model instance, or model class.
         """
@@ -1895,6 +1902,7 @@ def infer_task_from_model(
                 revision=revision,
                 cache_dir=cache_dir,
                 token=token,
+                library_name=library_name,
             )
         elif type(model) == type:
             inferred_task_name = cls._infer_task_from_model_or_model_class(model_class=model)
@@ -2170,6 +2178,9 @@ def get_model_from_task(
                 none be provided.
             model_kwargs (`Dict[str, Any]`, *optional*):
                 Keyword arguments to pass to the model `.from_pretrained()` method.
+            library_name (`Optional[str]`, defaults to `None`):
+                The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". See `TasksManager.infer_library_from_model` for the priority should
+                none be provided.
 
         Returns:
             The instance of the model.
@@ -2189,7 +2200,12 @@ def get_model_from_task(
         original_task = task
         if task == "auto":
             task = TasksManager.infer_task_from_model(
-                model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
+                model_name_or_path,
+                subfolder=subfolder,
+                revision=revision,
+                cache_dir=cache_dir,
+                token=token,
+                library_name=library_name,
             )
 
         model_type = None

optimum/exporters/tflite/__main__.py (+7 -2)

@@ -46,7 +46,7 @@ def main():
     task = args.task
     if task == "auto":
         try:
-            task = TasksManager.infer_task_from_model(args.model)
+            task = TasksManager.infer_task_from_model(args.model, library_name="transformers")
         except KeyError as e:
             raise KeyError(
                 "The task could not be automatically inferred. Please provide the argument --task with the task "
@@ -58,7 +58,12 @@ def main():
             )
 
     model = TasksManager.get_model_from_task(
-        task, args.model, framework="tf", cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code
+        task,
+        args.model,
+        framework="tf",
+        cache_dir=args.cache_dir,
+        trust_remote_code=args.trust_remote_code,
+        library_name="transformers",
     )
 
     tflite_config_constructor = TasksManager.get_exporter_config_constructor(
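Here the library is pinned to "transformers" outright rather than inferred, presumably because the TFLite exporter only handles transformers models. A sketch mirroring the updated call sites (model id illustrative, TensorFlow install assumed):

    from optimum.exporters.tasks import TasksManager

    model_id = "bert-base-uncased"  # illustrative checkpoint

    # Pinning the library skips infer_library_from_model entirely.
    task = TasksManager.infer_task_from_model(model_id, library_name="transformers")
    model = TasksManager.get_model_from_task(
        task,
        model_id,
        framework="tf",
        library_name="transformers",
    )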

optimum/exporters/tflite/convert.py (+1 -1)

@@ -194,7 +194,7 @@ def prepare_converter_for_quantization(
     if task is None:
         from ...exporters import TasksManager
 
-        task = TasksManager.infer_task_from_model(model)
+        task = TasksManager.infer_task_from_model(model, library_name="transformers")
 
     preprocessor_kwargs = {}
     if isinstance(preprocessor, PreTrainedTokenizerBase):
