Skip to content

Commit 2a789d6

Browse files
authoredFeb 6, 2024
ORTModelForFeatureExtraction always exports as transformers models (#1684)
fix
1 parent da6f9e2 commit 2a789d6

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed
 

‎optimum/onnxruntime/modeling_ort.py

+54
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,60 @@ def forward(
10021002
# converts output to namedtuple for pipelines post-processing
10031003
return BaseModelOutput(last_hidden_state=last_hidden_state)
10041004

1005+
@classmethod
1006+
def _export(
1007+
cls,
1008+
model_id: str,
1009+
config: "PretrainedConfig",
1010+
use_auth_token: Optional[Union[bool, str]] = None,
1011+
revision: Optional[str] = None,
1012+
force_download: bool = False,
1013+
cache_dir: Optional[str] = None,
1014+
subfolder: str = "",
1015+
local_files_only: bool = False,
1016+
trust_remote_code: bool = False,
1017+
provider: str = "CPUExecutionProvider",
1018+
session_options: Optional[ort.SessionOptions] = None,
1019+
provider_options: Optional[Dict[str, Any]] = None,
1020+
use_io_binding: Optional[bool] = None,
1021+
task: Optional[str] = None,
1022+
) -> "ORTModel":
1023+
if task is None:
1024+
task = cls._auto_model_to_task(cls.auto_model_class)
1025+
1026+
save_dir = TemporaryDirectory()
1027+
save_dir_path = Path(save_dir.name)
1028+
1029+
# ORTModelForFeatureExtraction works with Transformers type of models, thus even sentence-transformers models are loaded as such.
1030+
main_export(
1031+
model_name_or_path=model_id,
1032+
output=save_dir_path,
1033+
task=task,
1034+
do_validation=False,
1035+
no_post_process=True,
1036+
subfolder=subfolder,
1037+
revision=revision,
1038+
cache_dir=cache_dir,
1039+
use_auth_token=use_auth_token,
1040+
local_files_only=local_files_only,
1041+
force_download=force_download,
1042+
trust_remote_code=trust_remote_code,
1043+
library_name="transformers",
1044+
)
1045+
1046+
config.save_pretrained(save_dir_path)
1047+
maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
1048+
1049+
return cls._from_pretrained(
1050+
save_dir_path,
1051+
config,
1052+
use_io_binding=use_io_binding,
1053+
model_save_dir=save_dir,
1054+
provider=provider,
1055+
session_options=session_options,
1056+
provider_options=provider_options,
1057+
)
1058+
10051059

10061060
MASKED_LM_EXAMPLE = r"""
10071061
Example of feature extraction:

‎tests/onnxruntime/test_modeling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,7 @@ def test_trust_remote_code(self):
12271227
class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
12281228
SUPPORTED_ARCHITECTURES = [
12291229
"albert",
1230-
"bart",
1230+
"bart",
12311231
"bert",
12321232
# "big_bird",
12331233
# "bigbird_pegasus",
@@ -1592,7 +1592,7 @@ def test_compare_to_io_binding(self, model_arch):
15921592
class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
15931593
SUPPORTED_ARCHITECTURES = [
15941594
"albert",
1595-
"bart",
1595+
"bart",
15961596
"bert",
15971597
# "big_bird",
15981598
# "bigbird_pegasus",

0 commit comments

Comments
 (0)
Please sign in to comment.