Commit 7d2f95e

dwyatte authored and young-developer committed May 10, 2024

re-enable decoder sequence classification (huggingface#1679)

* re-enable decoder sequence classification
* update tests
* revert to better pad token handling logic
* minor updates
* format
1 parent 3dd0d93 commit 7d2f95e

File tree

5 files changed: +22 -29 lines


optimum/exporters/onnx/__main__.py (+5 -9)
@@ -410,23 +410,19 @@ def main_export(
         **loading_kwargs,
     )
 
-    needs_pad_token_id = (
-        task == "text-classification"
-        and getattr(model.config, "pad_token_id", None)
-        and getattr(model.config, "is_decoder", False)
-    )
+    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
 
     if needs_pad_token_id:
         if pad_token_id is not None:
             model.config.pad_token_id = pad_token_id
         else:
-            try:
-                tok = AutoTokenizer.from_pretrained(model_name_or_path)
-                model.config.pad_token_id = tok.pad_token_id
-            except Exception:
+            tok = AutoTokenizer.from_pretrained(model_name_or_path)
+            pad_token_id = getattr(tok, "pad_token_id", None)
+            if pad_token_id is None:
                 raise ValueError(
                     "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                 )
+            model.config.pad_token_id = pad_token_id
 
     if "stable-diffusion" in task:
         model_type = "stable-diffusion"

optimum/exporters/onnx/model_configs.py (-10)
@@ -203,16 +203,6 @@ class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
 
-    @property
-    def values_override(self) -> Optional[Dict[str, Any]]:
-        pad_value_override = {}
-        if not getattr(self._config, "pad_token_id", None):
-            pad_value_override = {"pad_token_id": 0}
-        super_values_override = super().values_override
-        if super_values_override:
-            return {**super_values_override, **pad_value_override}
-        return pad_value_override
-
 
 class GPTJOnnxConfig(GPT2OnnxConfig):
     pass
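
For contrast, the deleted override meant GPT-2 exports silently assumed a pad token id of 0 whenever the config had none. A simplified restatement of that old behavior (hypothetical helper name, for illustration only):

# Simplified restatement of the removed values_override logic: a missing
# pad_token_id was quietly forced to 0 instead of being resolved from the
# tokenizer or surfaced to the user.
def old_pad_override(config) -> dict:
    if not getattr(config, "pad_token_id", None):
        return {"pad_token_id": 0}
    return {}

Dropping it moves the decision into main_export, where a wrong guess fails loudly instead of producing a model padded with an arbitrary token.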

optimum/exporters/tasks.py (+10 -7)
@@ -603,7 +603,7 @@ class TasksManager:
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             "token-classification",
             onnx="GPT2OnnxConfig",
         ),
@@ -612,7 +612,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             "token-classification",
             onnx="GPTBigCodeOnnxConfig",
         ),
@@ -622,22 +622,23 @@
             "text-generation",
             "text-generation-with-past",
             "question-answering",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="GPTJOnnxConfig",
         ),
         "gpt-neo": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="GPTNeoOnnxConfig",
         ),
         "gpt-neox": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="GPTNeoXOnnxConfig",
         ),
         "groupvit": supported_tasks_mapping(
@@ -734,7 +735,7 @@ class TasksManager:
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification",
+            "text-classification",
             onnx="MistralOnnxConfig",
         ),
         # TODO: enable once the missing operator is supported.
@@ -782,6 +783,7 @@ class TasksManager:
         "mpt": supported_tasks_mapping(
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="MPTOnnxConfig",
         ),
         "mt5": supported_tasks_mapping(
@@ -818,15 +820,15 @@
             "text-generation",
             "text-generation-with-past",
             "question-answering",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="OPTOnnxConfig",
         ),
         "llama": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
-            # "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
+            "text-classification",
             onnx="LlamaOnnxConfig",
         ),
         "pegasus": supported_tasks_mapping(
@@ -849,6 +851,7 @@
             "feature-extraction-with-past",
             "text-generation",
             "text-generation-with-past",
+            "text-classification",
             onnx="PhiOnnxConfig",
         ),
         "pix2struct": supported_tasks_mapping(

tests/exporters/onnx/test_exporters_onnx_cli.py (+7)
@@ -185,6 +185,12 @@ def _onnx_export(
         no_dynamic_axes: bool = False,
         model_kwargs: Optional[Dict] = None,
     ):
+        # We need to set this to some value to be able to test the outputs values for batch size > 1.
+        if task == "text-classification":
+            pad_token_id = 0
+        else:
+            pad_token_id = None
+
         with TemporaryDirectory() as tmpdir:
             try:
                 main_export(
@@ -198,6 +204,7 @@ def _onnx_export(
                     no_post_process=no_post_process,
                     _variant=variant,
                     no_dynamic_axes=no_dynamic_axes,
+                    pad_token_id=pad_token_id,
                     model_kwargs=model_kwargs,
                 )
             except MinimumVersionError as e:
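
For context, padding is what makes batch size > 1 testable: uneven sequences must be padded into one rectangular tensor before outputs can be compared. A small sketch (checkpoint name illustrative; in GPT-2's BPE vocabulary, token id 0 is "!", matching pad_token_id = 0 above):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = "!"  # token id 0 in GPT-2's vocabulary, mirroring pad_token_id = 0
batch = tok(["short", "a noticeably longer input"], padding=True, return_tensors="np")
print(batch["input_ids"].shape)  # (2, max_len); the shorter row is padded with id 0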

tests/exporters/onnx/test_onnx_config_loss.py (-3)
@@ -123,9 +123,6 @@ def test_onnx_config_with_loss(self):
         gc.collect()
 
     def test_onnx_decoder_model_with_config_with_loss(self):
-        self.skipTest(
-            "Skipping due to a bug introduced in transformers with https://github.com/huggingface/transformers/pull/24979, argmax on int64 is not supported by ONNX"
-        )
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Prepare model and dataset
             model_checkpoint = "hf-internal-testing/tiny-random-gpt2"
