re-enable decoder sequence classification #1679

Merged 5 commits on Feb 8, 2024. Changes from all commits are shown below.
optimum/exporters/onnx/__main__.py (14 changes: 5 additions & 9 deletions)
@@ -410,23 +410,19 @@ def main_export(
         **loading_kwargs,
     )

-    needs_pad_token_id = (
-        task == "text-classification"
-        and getattr(model.config, "pad_token_id", None)
-        and getattr(model.config, "is_decoder", False)
-    )
+    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None

     if needs_pad_token_id:
         if pad_token_id is not None:
             model.config.pad_token_id = pad_token_id
         else:
-            try:
-                tok = AutoTokenizer.from_pretrained(model_name_or_path)
-                model.config.pad_token_id = tok.pad_token_id
-            except Exception:
+            tok = AutoTokenizer.from_pretrained(model_name_or_path)
+            pad_token_id = getattr(tok, "pad_token_id", None)
+            if pad_token_id is None:
                 raise ValueError(
                     "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                 )
+            model.config.pad_token_id = pad_token_id

     if "stable-diffusion" in task:
         model_type = "stable-diffusion"
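Note on the new behavior: the old predicate required getattr(model.config, "pad_token_id", None) to be truthy, so the branch that sets a missing pad token could never fire for the configs that actually lack one; the rewritten check triggers exactly when the config has no pad token id. A minimal usage sketch (the checkpoint name and output path are placeholders, not taken from this PR):

    from optimum.exporters.onnx import main_export

    # Sketch: export a decoder-only checkpoint with a sequence classification head.
    # pad_token_id is only consulted when model.config.pad_token_id is None;
    # without it the exporter falls back to the tokenizer, then raises.
    main_export(
        "hf-internal-testing/tiny-random-gpt2",  # placeholder checkpoint
        output="gpt2_seq_cls_onnx",              # placeholder output directory
        task="text-classification",
        pad_token_id=0,
    )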
optimum/exporters/onnx/model_configs.py (10 changes: 0 additions & 10 deletions)
@@ -203,16 +203,6 @@ class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

-    @property
-    def values_override(self) -> Optional[Dict[str, Any]]:
-        pad_value_override = {}
-        if not getattr(self._config, "pad_token_id", None):
-            pad_value_override = {"pad_token_id": 0}
-        super_values_override = super().values_override
-        if super_values_override:
-            return {**super_values_override, **pad_value_override}
-        return pad_value_override
-

 class GPTJOnnxConfig(GPT2OnnxConfig):
     pass
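Context for this removal: values_override silently forced pad_token_id to 0 whenever the config lacked one, which can disagree with the tokenizer actually used at inference time. With the fallback in __main__.py above, the id is now resolved explicitly, and a checkpoint whose tokenizer defines no pad token fails loudly. A small illustration (vanilla gpt2 is used here purely as a well-known checkpoint without a pad token):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2's tokenizer defines no pad token, so this prints None and the
    # exporter would raise, asking for --pad_token_id instead of silently
    # assuming 0 as the removed values_override did.
    print(getattr(tok, "pad_token_id", None))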
optimum/exporters/tasks.py (17 changes: 10 additions & 7 deletions)
@@ -603,7 +603,7 @@ class TasksManager:
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"text-classification",
"token-classification",
onnx="GPT2OnnxConfig",
),
@@ -612,7 +612,7 @@
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"text-classification",
"token-classification",
onnx="GPTBigCodeOnnxConfig",
),
@@ -622,22 +622,23 @@
"text-generation",
"text-generation-with-past",
"question-answering",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"text-classification",
onnx="GPTJOnnxConfig",
),
"gpt-neo": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"text-classification",
onnx="GPTNeoOnnxConfig",
),
"gpt-neox": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="GPTNeoXOnnxConfig",
),
"groupvit": supported_tasks_mapping(
@@ -734,7 +735,7 @@
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
# "text-classification",
"text-classification",
onnx="MistralOnnxConfig",
),
# TODO: enable once the missing operator is supported.
@@ -782,6 +783,7 @@
"mpt": supported_tasks_mapping(
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="MPTOnnxConfig",
),
"mt5": supported_tasks_mapping(
@@ -818,15 +820,15 @@
"text-generation",
"text-generation-with-past",
"question-answering",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"text-classification",
onnx="OPTOnnxConfig",
),
"llama": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"text-classification",
onnx="LlamaOnnxConfig",
),
"pegasus": supported_tasks_mapping(
@@ -849,6 +851,7 @@
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="PhiOnnxConfig",
),
"pix2struct": supported_tasks_mapping(
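The re-enabled entries can be verified programmatically. A quick sketch, assuming TasksManager.get_supported_tasks_for_model_type keeps the signature it has around the time of this PR:

    from optimum.exporters.tasks import TasksManager

    # "text-classification" should now appear among the ONNX tasks for gpt2.
    tasks = TasksManager.get_supported_tasks_for_model_type("gpt2", exporter="onnx")
    print(sorted(tasks))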
tests/exporters/onnx/test_exporters_onnx_cli.py (7 changes: 7 additions & 0 deletions)
@@ -185,6 +185,12 @@ def _onnx_export(
         no_dynamic_axes: bool = False,
         model_kwargs: Optional[Dict] = None,
     ):
+        # We need to set the pad token id to some value to be able to test the output values for batch size > 1.
+        if task == "text-classification":
+            pad_token_id = 0
+        else:
+            pad_token_id = None
+
         with TemporaryDirectory() as tmpdir:
             try:
                 main_export(
@@ -198,6 +204,7 @@
                     no_post_process=no_post_process,
                     _variant=variant,
                     no_dynamic_axes=no_dynamic_axes,
+                    pad_token_id=pad_token_id,
                     model_kwargs=model_kwargs,
                 )
             except MinimumVersionError as e:
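Why the pad token matters for this test: in transformers, GPT-2-style sequence classification heads pool the hidden state of the last non-padding token, located via config.pad_token_id, so padded batches cannot be scored without it. A sketch of exercising a padded batch against the exported model (the model path and token ids are placeholders; assumes the export used pad_token_id=0):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("gpt2_seq_cls_onnx/model.onnx")  # placeholder path
    pad_id = 0  # must match the pad_token_id baked into the exported config

    # Two sequences; the shorter one is right-padded with pad_id and masked out.
    input_ids = np.array([[464, 3290, 11], [464, pad_id, pad_id]], dtype=np.int64)
    attention_mask = np.array([[1, 1, 1], [1, 0, 0]], dtype=np.int64)

    outputs = sess.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})
    print(outputs[0].shape)  # (batch_size, num_labels)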
tests/exporters/onnx/test_onnx_config_loss.py (3 changes: 0 additions & 3 deletions)
@@ -123,9 +123,6 @@ def test_onnx_config_with_loss(self):
         gc.collect()

     def test_onnx_decoder_model_with_config_with_loss(self):
-        self.skipTest(
-            "Skipping due to a bug introduced in transformers with https://github.com/huggingface/transformers/pull/24979, argmax on int64 is not supported by ONNX"
-        )
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Prepare model and dataset
             model_checkpoint = "hf-internal-testing/tiny-random-gpt2"