Skip to content

Commit f22655c

Browse files
Add RemBERT ONNX support (#2108)
* ONNX config for RemBERT added * added RemBERT to TasksManager * rembert added to exporters_utils * RemBERT added to test modelling tasks * changed rembert model * added RemBERT to test utils * Added RemBERT to documentation * Apply suggestions from code review --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
1 parent 28bd0ad commit f22655c

File tree

6 files changed

+22
-1
lines changed

6 files changed

+22
-1
lines changed

docs/source/exporters/onnx/overview.mdx

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
8383
- PoolFormer
8484
- Qwen2(Qwen1.5)
8585
- RegNet
86+
- RemBERT
8687
- ResNet
8788
- Roberta
8889
- Roformer

optimum/exporters/onnx/model_configs.py

+4
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ class SplinterOnnxConfig(BertOnnxConfig):
162162
DEFAULT_ONNX_OPSET = 11
163163

164164

165+
class RemBertOnnxConfig(BertOnnxConfig):
166+
DEFAULT_ONNX_OPSET = 11
167+
168+
165169
class DistilBertOnnxConfig(BertOnnxConfig):
166170
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for transformers>=4.46.0
167171

optimum/exporters/tasks.py

+9
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,15 @@ class TasksManager:
431431
onnx="BertOnnxConfig",
432432
tflite="BertTFLiteConfig",
433433
),
434+
"rembert": supported_tasks_mapping(
435+
"fill-mask",
436+
"feature-extraction",
437+
"text-classification",
438+
"multiple-choice",
439+
"token-classification",
440+
"question-answering",
441+
onnx="RemBertOnnxConfig",
442+
),
434443
# For big-bird and bigbird-pegasus being unsupported, refer to model_configs.py
435444
# "big-bird": supported_tasks_mapping(
436445
# "feature-extraction",

tests/exporters/exporters_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
"phi3": "Xenova/tiny-random-Phi3ForCausalLM",
139139
"pix2struct": "fxmarty/pix2struct-tiny-random",
140140
# "rembert": "google/rembert",
141+
"rembert": "hf-internal-testing/tiny-random-RemBertModel",
141142
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
142143
"qwen2": "fxmarty/tiny-dummy-qwen2",
143144
"regnet": "hf-internal-testing/tiny-random-RegNetModel",
@@ -257,7 +258,7 @@
257258
"owlv2": "google/owlv2-base-patch16",
258259
"owlvit": "google/owlvit-base-patch32",
259260
"perceiver": "hf-internal-testing/tiny-random-PerceiverModel", # Not using deepmind/language-perceiver because it takes too much time for testing.
260-
# "rembert": "google/rembert",
261+
"rembert": "google/rembert",
261262
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
262263
"regnet": "facebook/regnet-y-040",
263264
"resnet": "microsoft/resnet-50",

tests/onnxruntime/test_modeling.py

+5
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,7 @@ class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
13121312
"squeezebert",
13131313
"xlm_qa",
13141314
"xlm_roberta",
1315+
"rembert",
13151316
]
13161317

13171318
FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
@@ -1502,6 +1503,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin):
15021503
"squeezebert",
15031504
"xlm",
15041505
"xlm_roberta",
1506+
"rembert",
15051507
]
15061508

15071509
FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
@@ -1682,6 +1684,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
16821684
"squeezebert",
16831685
"xlm",
16841686
"xlm_roberta",
1687+
"rembert",
16851688
]
16861689

16871690
FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
@@ -1882,6 +1885,7 @@ class ORTModelForTokenClassificationIntegrationTest(ORTModelTestMixin):
18821885
"squeezebert",
18831886
"xlm",
18841887
"xlm_roberta",
1888+
"rembert",
18851889
]
18861890

18871891
FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
@@ -2227,6 +2231,7 @@ class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin):
22272231
"squeezebert",
22282232
"xlm",
22292233
"xlm_roberta",
2234+
"rembert",
22302235
]
22312236

22322237
FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}

tests/onnxruntime/utils_onnxruntime_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
"pix2struct": "fxmarty/pix2struct-tiny-random",
136136
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
137137
"qwen2": "fxmarty/tiny-dummy-qwen2",
138+
"rembert": "hf-internal-testing/tiny-random-RemBertModel",
138139
"resnet": "hf-internal-testing/tiny-random-resnet",
139140
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
140141
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",

0 commit comments

Comments
 (0)