Skip to content

Commit b9fa9aa

Browse files
xenovaecharlaix
andauthored
Add ONNX export support for PatchTST (#2101)
* Add ONNX export support for `PatchTST` * Add unit test for patchtst * Add listed support for PatchTST * Add ONNX export support for patchtsmixer * Add task=feature-extraction * Fix ONNX compatible unfold * Formatting * Correctly handle negative indexing for onnx compatible unfold * Update tests/exporters/exporters_utils.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update optimum/exporters/tasks.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Move dummy patch tst input generator to input_generators.py * Code formatting --------- Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 8d1347f commit b9fa9aa

File tree

9 files changed

+125
-2
lines changed

9 files changed

+125
-2
lines changed

docs/source/exporters/onnx/overview.mdx

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
8282
- OLMo
8383
- OLMo2
8484
- OWL-ViT
85+
- PatchTST
86+
- PatchTSMixer
8587
- Pegasus
8688
- Perceiver
8789
- Phi

optimum/exporters/onnx/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ class OnnxConfig(ExportConfig, ABC):
180180
"text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
181181
"text-classification": OrderedDict({"logits": {0: "batch_size"}}),
182182
"text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
183+
"time-series-forecasting": OrderedDict({"prediction_outputs": {0: "batch_size"}}),
183184
"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
184185
"visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
185186
"zero-shot-image-classification": OrderedDict(

optimum/exporters/onnx/model_configs.py

+23
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
DummyInputGenerator,
3535
DummyIntGenerator,
3636
DummyPastKeyValuesGenerator,
37+
DummyPatchTSTInputGenerator,
3738
DummyPix2StructInputGenerator,
3839
DummyPointsGenerator,
3940
DummySeq2SeqDecoderTextInputGenerator,
@@ -58,6 +59,7 @@
5859
NormalizedTextAndVisionConfig,
5960
NormalizedTextConfig,
6061
NormalizedTextConfigWithGQA,
62+
NormalizedTimeSeriesForecastingConfig,
6163
NormalizedVisionConfig,
6264
is_diffusers_available,
6365
is_diffusers_version,
@@ -2619,3 +2621,24 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
26192621
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
26202622

26212623
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
2624+
2625+
2626+
class PatchTSTOnnxConfig(OnnxConfig):
2627+
NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig
2628+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyPatchTSTInputGenerator,)
2629+
ATOL_FOR_VALIDATION = 1e-4
2630+
2631+
@property
2632+
def inputs(self) -> Dict[str, Dict[int, str]]:
2633+
return {"past_values": {0: "batch_size", 1: "sequence_length"}}
2634+
2635+
@property
2636+
def outputs(self) -> Dict[str, Dict[int, str]]:
2637+
if self.task == "feature-extraction":
2638+
return {"last_hidden_state": {0: "batch_size"}}
2639+
else:
2640+
return super().outputs
2641+
2642+
2643+
class PatchTSMixerOnnxConfig(PatchTSTOnnxConfig):
2644+
pass

optimum/exporters/onnx/model_patcher.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,53 @@ class PatchingSpec:
113113
op_wrapper: Optional[Callable] = None
114114

115115

116+
# An ONNX-export-compatible version of `tensor.unfold`. Without this, we get:
117+
# torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.
118+
# See https://github.com/pytorch/pytorch/issues/81871 for more information
119+
def onnx_compatible_unfold(input_tensor, dimension, size, step):
120+
"""
121+
Custom implementation of torch.unfold without using torch.unfold.
122+
123+
Args:
124+
input_tensor (torch.Tensor): The input tensor.
125+
dimension (int): The dimension to unfold.
126+
size (int): The size of each slice.
127+
step (int): The step size between slices.
128+
129+
Returns:
130+
torch.Tensor: The unfolded tensor.
131+
"""
132+
# Check if dimension is within the valid range
133+
if not (-input_tensor.dim() <= dimension < input_tensor.dim()):
134+
raise ValueError(
135+
f"Dimension out of range (expected to be in range of [{-input_tensor.dim()}, {input_tensor.dim() - 1}], but got {dimension})"
136+
)
137+
138+
# Normalize negative dimension
139+
dimension = dimension % input_tensor.dim()
140+
141+
# Compute the shape of the unfolded output
142+
input_size = input_tensor.size(dimension)
143+
num_slices = (input_size - size) // step + 1
144+
145+
# Permute dimension to the end for easier indexing
146+
input_tensor = input_tensor.transpose(dimension, -1)
147+
148+
# Extract slices
149+
slices = []
150+
for i in range(num_slices):
151+
start = i * step
152+
end = start + size
153+
slices.append(input_tensor[..., start:end])
154+
155+
# Stack slices and permute dimensions back
156+
result = torch.stack(slices, dim=-2).transpose(dimension, -2)
157+
return result
158+
159+
160+
UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
161+
162+
116163
class ModelPatcher:
117164
def __init__(
118165
self,
@@ -122,9 +169,11 @@ def __init__(
122169
):
123170
self._model = model
124171

125-
patching_specs = config.PATCHING_SPECS
172+
patching_specs = config.PATCHING_SPECS or []
173+
patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC)
174+
126175
self._patching_specs = []
127-
for spec in patching_specs if patching_specs is not None else []:
176+
for spec in patching_specs:
128177
final_spec = spec
129178
if spec.orig_op is None:
130179
final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))

optimum/exporters/tasks.py

+12
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ class TasksManager:
322322
}
323323

324324
_CUSTOM_CLASSES = {
325+
("pt", "patchtsmixer", "time-series-forecasting"): ("transformers", "PatchTSMixerForPrediction"),
326+
("pt", "patchtst", "time-series-forecasting"): ("transformers", "PatchTSTForPrediction"),
325327
("pt", "pix2struct", "image-to-text"): ("transformers", "Pix2StructForConditionalGeneration"),
326328
("pt", "pix2struct", "visual-question-answering"): ("transformers", "Pix2StructForConditionalGeneration"),
327329
("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
@@ -962,6 +964,16 @@ class TasksManager:
962964
"text-classification",
963965
onnx="OPTOnnxConfig",
964966
),
967+
"patchtst": supported_tasks_mapping(
968+
"feature-extraction",
969+
"time-series-forecasting",
970+
onnx="PatchTSTOnnxConfig",
971+
),
972+
"patchtsmixer": supported_tasks_mapping(
973+
"feature-extraction",
974+
"time-series-forecasting",
975+
onnx="PatchTSMixerOnnxConfig",
976+
),
965977
"qwen2": supported_tasks_mapping(
966978
"feature-extraction",
967979
"feature-extraction-with-past",

optimum/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
DummyIntGenerator,
7070
DummyLabelsGenerator,
7171
DummyPastKeyValuesGenerator,
72+
DummyPatchTSTInputGenerator,
7273
DummyPix2StructInputGenerator,
7374
DummyPointsGenerator,
7475
DummySeq2SeqDecoderTextInputGenerator,
@@ -98,5 +99,6 @@
9899
NormalizedTextAndVisionConfig,
99100
NormalizedTextConfig,
100101
NormalizedTextConfigWithGQA,
102+
NormalizedTimeSeriesForecastingConfig,
101103
NormalizedVisionConfig,
102104
)

optimum/utils/input_generators.py

+27
Original file line numberDiff line numberDiff line change
@@ -1532,3 +1532,30 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
15321532
return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype)
15331533

15341534
return super().generate(input_name, framework, int_dtype, float_dtype)
1535+
1536+
1537+
class DummyPatchTSTInputGenerator(DummyInputGenerator):
1538+
SUPPORTED_INPUT_NAMES = ("past_values",)
1539+
1540+
def __init__(
1541+
self,
1542+
task: str,
1543+
normalized_config: NormalizedConfig,
1544+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
1545+
**kwargs,
1546+
):
1547+
self.task = task
1548+
self.normalized_config = normalized_config
1549+
1550+
self.batch_size = batch_size
1551+
self.context_length = normalized_config.context_length
1552+
self.num_input_channels = normalized_config.num_input_channels
1553+
1554+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1555+
return self.random_float_tensor(
1556+
shape=[self.batch_size, self.context_length, self.num_input_channels],
1557+
min_value=-1,
1558+
max_value=1,
1559+
framework=framework,
1560+
dtype=float_dtype,
1561+
)

optimum/utils/normalized_config.py

+5
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def has_attribute(self, attr_name):
7777
return True
7878

7979

80+
class NormalizedTimeSeriesForecastingConfig(NormalizedConfig):
81+
NUM_INPUT_CHANNELS = "num_input_channels"
82+
CONTEXT_LENGTH = "context_length"
83+
84+
8085
class NormalizedTextConfig(NormalizedConfig):
8186
VOCAB_SIZE = "vocab_size"
8287
HIDDEN_SIZE = "hidden_size"

tests/exporters/exporters_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@
136136
"opt": "hf-internal-testing/tiny-random-OPTModel",
137137
"owlv2": "hf-internal-testing/tiny-random-Owlv2Model",
138138
"owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel",
139+
"patchtst": "ibm/test-patchtst",
140+
"patchtsmixer": "ibm/test-patchtsmixer",
139141
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
140142
"perceiver": {
141143
"hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"],

0 commit comments

Comments
 (0)