
Commit 65a8a94

Add ONNX Support for Decision Transformer Model (#2038)
* Decision Transformer to ONNX V0.1
* Decision Transformer to ONNX V0.2
* Update optimum/exporters/onnx/model_configs.py
* Apply suggestions from code review
* Update optimum/exporters/onnx/base.py
* Update optimum/exporters/onnx/model_configs.py
* Update optimum/utils/input_generators.py
* Update optimum/exporters/onnx/model_configs.py
* Apply suggestions from code review
* Update optimum/exporters/tasks.py
* ONNXToDT: changes to order of OrderedDict elements
* make style changes
* test
* remove custom normalized config
* remove unnecessary dynamic axes

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: IlyasMoutawwakil <moutawwakil.ilyas.tsi@gmail.com>
1 parent d2a5a6a commit 65a8a94

File tree

6 files changed: +74 -0 lines changed

- docs/source/exporters/onnx/overview.mdx
- optimum/exporters/onnx/model_configs.py
- optimum/exporters/tasks.py
- optimum/utils/__init__.py
- optimum/utils/input_generators.py
- tests/exporters/exporters_utils.py

6 files changed

+74
-0
lines changed

docs/source/exporters/onnx/overview.mdx (+1)

@@ -36,6 +36,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - Data2VecVision
 - Deberta
 - Deberta-v2
+- Decision Transformer
 - Deit
 - Detr
 - DistilBert

optimum/exporters/onnx/model_configs.py (+25)

@@ -27,6 +27,7 @@
     BloomDummyPastKeyValuesGenerator,
     DummyAudioInputGenerator,
     DummyCodegenDecoderTextInputGenerator,
+    DummyDecisionTransformerInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
     DummyFluxTransformerTextInputGenerator,
@@ -263,6 +264,30 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
     pass


+class DecisionTransformerOnnxConfig(OnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyDecisionTransformerInputGenerator,)
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "states": {0: "batch_size", 1: "sequence_length"},
+            "actions": {0: "batch_size", 1: "sequence_length"},
+            "timesteps": {0: "batch_size", 1: "sequence_length"},
+            "returns_to_go": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "state_preds": {0: "batch_size", 1: "sequence_length"},
+            "action_preds": {0: "batch_size", 1: "sequence_length"},
+            "return_preds": {0: "batch_size", 1: "sequence_length"},
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+        }
+
+
 class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
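The dynamic axes above mean the exported graph accepts arbitrary batch sizes and trajectory lengths, with the input and output names taken from this config. Below is a minimal sketch (not part of this commit) of feeding such an export with onnxruntime; the model path and the Hopper-sized dimensions (`state_dim=11`, `act_dim=3`) are assumptions for illustration, and the dtypes mirror what the dummy input generator produces (everything float except `timesteps`).

```python
# Sketch only: the ONNX file path and the state/action dimensions are assumptions.
import numpy as np
import onnxruntime as ort

batch_size, seq_len, state_dim, act_dim = 1, 20, 11, 3

session = ort.InferenceSession("decision_transformer_onnx/model.onnx")
(action_preds,) = session.run(
    ["action_preds"],  # one of the four outputs declared in DecisionTransformerOnnxConfig
    {
        "states": np.zeros((batch_size, seq_len, state_dim), dtype=np.float32),
        "actions": np.zeros((batch_size, seq_len, act_dim), dtype=np.float32),
        "timesteps": np.arange(seq_len, dtype=np.int64).reshape(batch_size, seq_len),
        "returns_to_go": np.ones((batch_size, seq_len, 1), dtype=np.float32),
        "attention_mask": np.ones((batch_size, seq_len), dtype=np.float32),
    },
)
print(action_preds.shape)  # (batch_size, seq_len, act_dim)
```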

optimum/exporters/tasks.py (+9)

@@ -217,6 +217,7 @@ class TasksManager:
         "multiple-choice": "AutoModelForMultipleChoice",
         "object-detection": "AutoModelForObjectDetection",
         "question-answering": "AutoModelForQuestionAnswering",
+        "reinforcement-learning": "AutoModel",
         "semantic-segmentation": "AutoModelForSemanticSegmentation",
         "text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
         "text-generation": "AutoModelForCausalLM",
@@ -574,6 +575,11 @@ class TasksManager:
             onnx="DebertaV2OnnxConfig",
             tflite="DebertaV2TFLiteConfig",
         ),
+        "decision-transformer": supported_tasks_mapping(
+            "feature-extraction",
+            "reinforcement-learning",
+            onnx="DecisionTransformerOnnxConfig",
+        ),
         "deit": supported_tasks_mapping(
             "feature-extraction",
             "image-classification",
@@ -2085,6 +2091,9 @@ def get_model_from_task(
         if original_task == "automatic-speech-recognition" or task == "automatic-speech-recognition":
            if original_task == "auto" and config.architectures is not None:
                model_class_name = config.architectures[0]
+        elif original_task == "reinforcement-learning" or task == "reinforcement-learning":
+            if config.architectures is not None:
+                model_class_name = config.architectures[0]

         if library_name == "diffusers":
             config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
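With the `reinforcement-learning` task and the `decision-transformer` model type registered above, the export can be driven through the programmatic API (or, equivalently, `optimum-cli export onnx`). A rough sketch, using the same checkpoint that the test mapping below references; the output directory name is an arbitrary choice for illustration:

```python
# Sketch only: output directory name is arbitrary; the checkpoint is the one used in the tests.
from optimum.exporters.onnx import main_export

main_export(
    "edbeeching/decision-transformer-gym-hopper-medium",
    output="decision_transformer_onnx",
    task="reinforcement-learning",
)
```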

optimum/utils/__init__.py (+1)

@@ -53,6 +53,7 @@
     DummyAudioInputGenerator,
     DummyBboxInputGenerator,
     DummyCodegenDecoderTextInputGenerator,
+    DummyDecisionTransformerInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
     DummyFluxTransformerTextInputGenerator,
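This re-export only changes where the generator can be imported from; both import paths below should resolve to the same class. A quick sanity check, not part of the diff:

```python
# Both import paths refer to the same class after this change.
from optimum.utils import DummyDecisionTransformerInputGenerator
from optimum.utils.input_generators import DummyDecisionTransformerInputGenerator as _direct

assert DummyDecisionTransformerInputGenerator is _direct
```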

optimum/utils/input_generators.py (+37)

@@ -507,6 +507,43 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
     )


+class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
+    """
+    Generates dummy decision transformer inputs.
+    """
+
+    SUPPORTED_INPUT_NAMES = (
+        "states",
+        "actions",
+        "timesteps",
+        "returns_to_go",
+        "attention_mask",
+    )
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.act_dim = self.normalized_config.config.act_dim
+        self.state_dim = self.normalized_config.config.state_dim
+        self.max_ep_len = self.normalized_config.config.max_ep_len
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "states":
+            shape = [self.batch_size, self.sequence_length, self.state_dim]
+        elif input_name == "actions":
+            shape = [self.batch_size, self.sequence_length, self.act_dim]
+        elif input_name == "rewards":
+            shape = [self.batch_size, self.sequence_length, 1]
+        elif input_name == "returns_to_go":
+            shape = [self.batch_size, self.sequence_length, 1]
+        elif input_name == "attention_mask":
+            shape = [self.batch_size, self.sequence_length]
+        elif input_name == "timesteps":
+            shape = [self.batch_size, self.sequence_length]
+            return self.random_int_tensor(shape=shape, max_value=self.max_ep_len, framework=framework, dtype=int_dtype)
+
+        return self.random_float_tensor(shape, min_value=-2.0, max_value=2.0, framework=framework, dtype=float_dtype)
+
+
 class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
     SUPPORTED_INPUT_NAMES = (
         "decoder_input_ids",

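To see what this generator actually produces, the dummy inputs can be inspected through the ONNX config that consumes it. A minimal sketch, assuming the standard `generate_dummy_inputs` helper on the config and reusing the Hopper checkpoint from the test mapping below:

```python
# Sketch only: inspects the dummy inputs built by DummyDecisionTransformerInputGenerator.
from transformers import AutoConfig

from optimum.exporters.onnx.model_configs import DecisionTransformerOnnxConfig

config = AutoConfig.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
onnx_config = DecisionTransformerOnnxConfig(config, task="reinforcement-learning")

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
for name, tensor in dummy_inputs.items():
    # Everything except timesteps comes out as a float tensor; timesteps is an int tensor
    # bounded by config.max_ep_len, matching the branches in generate() above.
    print(name, tuple(tensor.shape), tensor.dtype)
```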
tests/exporters/exporters_utils.py (+1)

@@ -67,6 +67,7 @@
     "data2vec-audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
     "deberta": "hf-internal-testing/tiny-random-DebertaModel",
     "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model",
+    "decision-transformer": "edbeeching/decision-transformer-gym-hopper-medium",
     "deit": "hf-internal-testing/tiny-random-DeiTModel",
     "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
     "donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
