
Commit a7a807c

SD3 and Flux support (#2073)
* sd3 support
* unsupported cli model types
* flux transformer support, unet export fixes, updated callback test, updated negative prompt test, flux and sd3 tests
* fixes
* move input generators
* dummy diffusers
* style
* distribute ort tests
* fix
* fix
* fix
* test num images
* single process to reduce re-exports
* test
* revert unnecessary changes
* T5Encoder inherits from TextEncoder
* style
* fix typo in timestep
* style
* only test sd3 and flux on latest transformers
* conditional sd3 and flux modeling
* forgot sd3 inpaint
1 parent 400bb82 · commit a7a807c

18 files changed · +791 −217 lines
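With this commit, full SD3 and Flux pipelines become exportable to ONNX; their individual sub-models (text encoders, transformer, VAE) are registered as model types but flagged as unsupported for direct CLI export, so export goes through the pipeline-level entry point. A hedged usage sketch (the model id is an example; `main_export` is optimum's programmatic export entry point, and exact behavior depends on the installed diffusers/optimum versions):

```python
from optimum.exporters.onnx import main_export

# Export a whole SD3 pipeline to ONNX. Any SD3 or Flux checkpoint on the Hub
# should resolve the same way; the model id below is just an example.
main_export(
    model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers",
    output="sd3_onnx",
)
```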

.github/workflows/test_onnxruntime.yml (+4 −9)
```diff
@@ -26,14 +26,11 @@ jobs:
         os: ubuntu-20.04

     runs-on: ${{ matrix.os }}
+
     steps:
       - name: Free Disk Space (Ubuntu)
         if: matrix.os == 'ubuntu-20.04'
         uses: jlumbroso/free-disk-space@main
-        with:
-          tool-cache: false
-          swap-storage: false
-          large-packages: false

       - name: Checkout code
         uses: actions/checkout@v4
@@ -54,13 +51,11 @@ jobs:
       run: pip install transformers==${{ matrix.transformers-version }}

     - name: Test with pytest (in series)
-      working-directory: tests
       run: |
-        pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s
+        pytest tests/onnxruntime -m "run_in_series" --durations=0 -vvvv -s

     - name: Test with pytest (in parallel)
+      run: |
+        pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
       env:
         HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
-      working-directory: tests
-      run: |
-        pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
```

optimum/exporters/onnx/base.py (+1)
```diff
@@ -319,6 +319,7 @@ def fix_dynamic_axes(
             input_shapes = {}
         dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes)
         dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names)
+        dummy_inputs = self.rename_ambiguous_inputs(dummy_inputs)

         onnx_inputs = {}
         for name, value in dummy_inputs.items():
```
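This one-line fix makes `fix_dynamic_axes` apply the same input renaming the export path already uses, so the generated dummies reach the ONNX session under the graph's actual input names. A minimal sketch of the pattern (`VaeLikeConfig` and the `sample` to `latent_sample` mapping are hypothetical illustrations, not optimum's exact code):

```python
from typing import Any, Dict


class OnnxConfigSketch:
    """Default behavior: dummy input names pass through unchanged."""

    def rename_ambiguous_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs


class VaeLikeConfig(OnnxConfigSketch):
    """Hypothetical override: the torch model takes `sample`, but the exported
    graph exposes `latent_sample`, so dummies must be renamed before being fed
    to the ONNX session."""

    def rename_ambiguous_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        inputs["latent_sample"] = inputs.pop("sample")
        return inputs
```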

optimum/exporters/onnx/convert.py (+4)
```diff
@@ -1183,6 +1183,10 @@ def onnx_export_from_model(
         if tokenizer_2 is not None:
             tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

+        tokenizer_3 = getattr(model, "tokenizer_3", None)
+        if tokenizer_3 is not None:
+            tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
+
         model.save_config(output)

         if float_dtype == "bf16":
```

optimum/exporters/onnx/model_configs.py (+109 −14)
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model specific ONNX configurations."""
+
 import random
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
@@ -28,6 +29,8 @@
     DummyCodegenDecoderTextInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
+    DummyFluxTransformerTextInputGenerator,
+    DummyFluxTransformerVisionInputGenerator,
     DummyInputGenerator,
     DummyIntGenerator,
     DummyPastKeyValuesGenerator,
@@ -38,6 +41,9 @@
     DummySpeechT5InputGenerator,
     DummyTextInputGenerator,
     DummyTimestepInputGenerator,
+    DummyTransformerTextInputGenerator,
+    DummyTransformerTimestepInputGenerator,
+    DummyTransformerVisionInputGenerator,
     DummyVisionEmbeddingsGenerator,
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
@@ -53,6 +59,7 @@
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
     NormalizedVisionConfig,
+    check_if_diffusers_greater,
     check_if_transformers_greater,
     is_diffusers_available,
     logging,
@@ -1039,22 +1046,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
             "pooler_output": {0: "batch_size"},
         }
+
         if self._normalized_config.output_hidden_states:
             for i in range(self._normalized_config.num_layers + 1):
                 common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

         return common_outputs

-    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
-        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
-
-        # TODO: fix should be by casting inputs during inference and not export
-        if framework == "pt":
-            import torch
-
-            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
-        return dummy_inputs
-
     def patch_model_for_export(
         self,
         model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
@@ -1064,7 +1062,7 @@ def patch_model_for_export(


 class UNetOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-3
+    ATOL_FOR_VALIDATION = 1e-4
     # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
@@ -1087,17 +1085,19 @@ class UNetOnnxConfig(VisionOnnxConfig):
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = {
             "sample": {0: "batch_size", 2: "height", 3: "width"},
-            "timestep": {0: "steps"},
+            "timestep": {},  # a scalar with no dimension
             "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
         }

-        # TODO : add text_image, image and image_embeds
+        # TODO : add addition_embed_type == text_image, image and image_embeds
+        # https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671
         if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
             common_inputs["text_embeds"] = {0: "batch_size"}
             common_inputs["time_ids"] = {0: "batch_size"}

         if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
             common_inputs["timestep_cond"] = {0: "batch_size"}
+
         return common_inputs

     @property
@@ -1136,7 +1136,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:


 class VaeEncoderOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-4
+    ATOL_FOR_VALIDATION = 3e-4
     # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
```
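The axis dictionaries declared by these configs are what ultimately become the exporter's dynamic axes. A simplified sketch of that relationship (illustrative only; optimum's real export path layers validation, model patching, and input ordering on top), using the axes of the `SD3TransformerOnnxConfig` added below:

```python
# Axis dicts as declared by the SD3 transformer config below.
inputs = {
    "hidden_states": {0: "batch_size", 2: "height", 3: "width"},
    "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
    "pooled_projections": {0: "batch_size"},
    "timestep": {0: "step"},
}
outputs = {"out_hidden_states": {0: "batch_size", 2: "height", 3: "width"}}

# Roughly how they would feed torch.onnx.export (a sketch, not optimum's call):
# torch.onnx.export(
#     model, (sample_inputs,), "transformer.onnx",
#     input_names=list(inputs), output_names=list(outputs),
#     dynamic_axes={**inputs, **outputs}, opset_version=14,
# )
```

The new configs added for SD3 and Flux declare their axes the same way: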
```diff
@@ -1184,6 +1184,101 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         }


+class T5EncoderOnnxConfig(TextEncoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    ATOL_FOR_VALIDATION = 1e-4
+    DEFAULT_ONNX_OPSET = 12  # int64 was supported since opset 12
+
+    @property
+    def inputs(self):
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self):
+        return {
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+        }
+
+
+class SD3TransformerOnnxConfig(VisionOnnxConfig):
+    ATOL_FOR_VALIDATION = 1e-4
+    # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
+    # operator support, available since opset 14
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyTransformerVisionInputGenerator,
+        DummyTransformerTextInputGenerator,
+    )
+
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        image_size="sample_size",
+        num_channels="in_channels",
+        vocab_size="attention_head_dim",
+        hidden_size="joint_attention_dim",
+        projection_size="pooled_projection_dim",
+        allow_new=True,
+    )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {
+            "hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
+            "pooled_projections": {0: "batch_size"},
+            "timestep": {0: "step"},
+        }
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "out_hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+    @property
+    def torch_to_onnx_output_map(self) -> Dict[str, str]:
+        return {
+            "sample": "out_hidden_states",
+        }
+
+
+class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyFluxTransformerVisionInputGenerator,
+        DummyFluxTransformerTextInputGenerator,
+    )
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
+        common_inputs["txt_ids"] = (
+            {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"}
+        )
+        common_inputs["img_ids"] = (
+            {0: "packed_height_width"}
+            if check_if_diffusers_greater("0.31.0")
+            else {0: "batch_size", 1: "packed_height_width"}
+        )
+
+        if getattr(self._normalized_config, "guidance_embeds", False):
+            common_inputs["guidance"] = {0: "batch_size"}
+
+        return common_inputs
+
+    @property
+    def outputs(self):
+        return {
+            "out_hidden_states": {0: "batch_size", 1: "packed_height_width"},
+        }
+
+
 class GroupViTOnnxConfig(CLIPOnnxConfig):
     pass
```
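The `packed_height_width` axis in `FluxTransformerOnnxConfig` reflects that Flux's transformer consumes latents flattened into a sequence of 2×2 patches rather than a 4D feature map; likewise, the `txt_ids`/`img_ids` axes lose their batch dimension on diffusers >= 0.31, matching the upstream signature change. A sketch of the packing (modeled on diffusers' `FluxPipeline._pack_latents`, reproduced from memory, so treat the details as an assumption):

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, H/2 * W/2, C * 4): each 2x2 latent patch becomes one
    # sequence position, producing the "packed_height_width" dimension.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)


print(pack_latents(torch.randn(1, 16, 64, 64)).shape)  # torch.Size([1, 1024, 64])
```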

optimum/exporters/tasks.py (+23 −6)
```diff
@@ -335,15 +335,27 @@ class TasksManager:
     }

     _DIFFUSERS_SUPPORTED_MODEL_TYPE = {
-        "clip-text-model": supported_tasks_mapping(
+        "t5-encoder": supported_tasks_mapping(
+            "feature-extraction",
+            onnx="T5EncoderOnnxConfig",
+        ),
+        "clip-text": supported_tasks_mapping(
             "feature-extraction",
             onnx="CLIPTextOnnxConfig",
         ),
         "clip-text-with-projection": supported_tasks_mapping(
             "feature-extraction",
             onnx="CLIPTextWithProjectionOnnxConfig",
         ),
-        "unet": supported_tasks_mapping(
+        "flux-transformer-2d": supported_tasks_mapping(
+            "semantic-segmentation",
+            onnx="FluxTransformerOnnxConfig",
+        ),
+        "sd3-transformer-2d": supported_tasks_mapping(
+            "semantic-segmentation",
+            onnx="SD3TransformerOnnxConfig",
+        ),
+        "unet-2d-condition": supported_tasks_mapping(
             "semantic-segmentation",
             onnx="UNetOnnxConfig",
         ),
@@ -1177,12 +1189,17 @@ class TasksManager:
         "transformers": _SUPPORTED_MODEL_TYPE,
     }
     _UNSUPPORTED_CLI_MODEL_TYPE = {
-        "unet",
+        # diffusers model types
+        "clip-text",
+        "clip-text-with-projection",
+        "flux-transformer-2d",
+        "sd3-transformer-2d",
+        "t5-encoder",
+        "unet-2d-condition",
         "vae-encoder",
         "vae-decoder",
-        "clip-text-model",
-        "clip-text-with-projection",
-        "trocr",  # supported through the vision-encoder-decoder model type
+        # redundant model types
+        "trocr",  # same as vision-encoder-decoder
     }
     _SUPPORTED_CLI_MODEL_TYPE = (
         set(_SUPPORTED_MODEL_TYPE.keys())
```
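The `semantic-segmentation` task on the UNet/transformer entries appears to act as a proxy: it is the transformers task whose image-in/image-out signature is closest to a diffusion denoiser, and this PR extends that convention from `unet` to the new transformer types. A hedged lookup sketch (`get_exporter_config_constructor` is optimum's public resolution helper; exact keyword support may vary by version):

```python
from optimum.exporters.tasks import TasksManager

# Resolve the ONNX config class registered above for the SD3 transformer.
constructor = TasksManager.get_exporter_config_constructor(
    exporter="onnx",
    model_type="sd3-transformer-2d",
    task="semantic-segmentation",
    library_name="diffusers",
)
```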
