 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model specific ONNX configurations."""
+
 import random
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
...
     DummyCodegenDecoderTextInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
+    DummyFluxTransformerTextInputGenerator,
+    DummyFluxTransformerVisionInputGenerator,
     DummyInputGenerator,
     DummyIntGenerator,
     DummyPastKeyValuesGenerator,
...
     DummySpeechT5InputGenerator,
     DummyTextInputGenerator,
     DummyTimestepInputGenerator,
+    DummyTransformerTextInputGenerator,
+    DummyTransformerTimestepInputGenerator,
+    DummyTransformerVisionInputGenerator,
     DummyVisionEmbeddingsGenerator,
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
...
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
     NormalizedVisionConfig,
+    check_if_diffusers_greater,
     check_if_transformers_greater,
     is_diffusers_available,
     logging,
@@ -1039,22 +1046,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
             "pooler_output": {0: "batch_size"},
         }
+
         if self._normalized_config.output_hidden_states:
             for i in range(self._normalized_config.num_layers + 1):
                 common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

         return common_outputs

-    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
-        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
-
-        # TODO: fix should be by casting inputs during inference and not export
-        if framework == "pt":
-            import torch
-
-            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
-        return dummy_inputs
-
     def patch_model_for_export(
         self,
         model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
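The override removed above forced the dummy input_ids to int32 at export time; its TODO points out that any dtype cast belongs at inference time rather than being baked into the exported graph. A minimal sketch of that inference-time cast with ONNX Runtime (the model path and tensor values are illustrative, not taken from this PR):

import numpy as np
import onnxruntime as ort

# Hypothetical exported text encoder; the path is only an example.
session = ort.InferenceSession("text_encoder/model.onnx")

input_ids = np.array([[101, 2023, 102]], dtype=np.int64)
# Cast when feeding the model, if the exported graph happens to expect int32 ids.
if session.get_inputs()[0].type == "tensor(int32)":
    input_ids = input_ids.astype(np.int32)

outputs = session.run(None, {"input_ids": input_ids})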
@@ -1064,7 +1062,7 @@ def patch_model_for_export(


 class UNetOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-3
+    ATOL_FOR_VALIDATION = 1e-4
     # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
@@ -1087,17 +1085,19 @@ class UNetOnnxConfig(VisionOnnxConfig):
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = {
             "sample": {0: "batch_size", 2: "height", 3: "width"},
-            "timestep": {0: "steps"},
+            "timestep": {},  # a scalar with no dimension
             "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
         }

-        # TODO: add text_image, image and image_embeds
+        # TODO: add addition_embed_type == text_image, image and image_embeds
+        # https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671
         if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
             common_inputs["text_embeds"] = {0: "batch_size"}
             common_inputs["time_ids"] = {0: "batch_size"}

         if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
             common_inputs["timestep_cond"] = {0: "batch_size"}
+
         return common_inputs

     @property
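On the timestep change above: an empty axes dict exports the input as a rank-0 (scalar) tensor with nothing to mark dynamic, instead of pretending it has a "steps" axis. A self-contained sketch of how such an inputs mapping feeds torch.onnx.export (the toy module and shapes are illustrative, not optimum's exporter):

import torch

class ToyUNet(torch.nn.Module):
    def forward(self, sample, timestep, encoder_hidden_states):
        # Scale the latent by the scalar timestep and add a text-conditioning term.
        return sample * timestep.float() + encoder_hidden_states.mean()

dynamic_axes = {
    "sample": {0: "batch_size", 2: "height", 3: "width"},
    "timestep": {},  # rank-0 input: no axis to mark dynamic
    "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
    "out_sample": {0: "batch_size", 2: "height", 3: "width"},
}
dummy = (torch.randn(2, 4, 64, 64), torch.tensor(999), torch.randn(2, 77, 768))
torch.onnx.export(
    ToyUNet(),
    dummy,
    "toy_unet.onnx",
    input_names=["sample", "timestep", "encoder_hidden_states"],
    output_names=["out_sample"],
    dynamic_axes=dynamic_axes,
    opset_version=14,
)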
@@ -1136,7 +1136,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:


 class VaeEncoderOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-4
+    ATOL_FOR_VALIDATION = 3e-4
     # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
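ATOL_FOR_VALIDATION is the absolute tolerance used when the exported model's outputs are compared against the PyTorch reference, so raising it to 3e-4 only tolerates slightly more numerical drift. A rough illustration of what that check amounts to (not optimum's validation code):

import numpy as np

reference = np.array([0.12345, -0.54321, 0.98765], dtype=np.float32)  # PyTorch reference outputs
exported = reference + 2e-4  # small drift introduced by the exported graph
assert np.allclose(exported, reference, atol=3e-4)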
@@ -1184,6 +1184,101 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         }


+class T5EncoderOnnxConfig(TextEncoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    ATOL_FOR_VALIDATION = 1e-4
+    DEFAULT_ONNX_OPSET = 12  # int64 was supported since opset 12
+
+    @property
+    def inputs(self):
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self):
+        return {
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+        }
+
+
+class SD3TransformerOnnxConfig(VisionOnnxConfig):
+    ATOL_FOR_VALIDATION = 1e-4
+    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
+    # operator support, available since opset 14
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyTransformerVisionInputGenerator,
+        DummyTransformerTextInputGenerator,
+    )
+
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        image_size="sample_size",
+        num_channels="in_channels",
+        vocab_size="attention_head_dim",
+        hidden_size="joint_attention_dim",
+        projection_size="pooled_projection_dim",
+        allow_new=True,
+    )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {
+            "hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
+            "pooled_projections": {0: "batch_size"},
+            "timestep": {0: "step"},
+        }
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "out_hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+    @property
+    def torch_to_onnx_output_map(self) -> Dict[str, str]:
+        return {
+            "sample": "out_hidden_states",
+        }
+
+
+class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyFluxTransformerVisionInputGenerator,
+        DummyFluxTransformerTextInputGenerator,
+    )
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
+        common_inputs["txt_ids"] = (
+            {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"}
+        )
+        common_inputs["img_ids"] = (
+            {0: "packed_height_width"}
+            if check_if_diffusers_greater("0.31.0")
+            else {0: "batch_size", 1: "packed_height_width"}
+        )
+
+        if getattr(self._normalized_config, "guidance_embeds", False):
+            common_inputs["guidance"] = {0: "batch_size"}
+
+        return common_inputs
+
+    @property
+    def outputs(self):
+        return {
+            "out_hidden_states": {0: "batch_size", 1: "packed_height_width"},
+        }
+
+
 class GroupViTOnnxConfig(CLIPOnnxConfig):
     pass
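Note on FluxTransformerOnnxConfig above: the txt_ids/img_ids axes depend on the installed diffusers version because, per the diff, those tensors are 2D (no batch dimension) from diffusers 0.31.0 onwards and 3D before. A minimal sketch of that kind of version gate (the helper name is hypothetical; the PR itself uses the imported check_if_diffusers_greater):

import importlib.metadata

from packaging import version


def diffusers_at_least(target: str) -> bool:
    # Hypothetical helper playing the role of check_if_diffusers_greater.
    try:
        installed = importlib.metadata.version("diffusers")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(target)


# txt_ids keeps only a sequence axis on recent diffusers, and a batch axis as well on older releases.
txt_ids_axes = {0: "sequence_length"} if diffusers_at_least("0.31.0") else {0: "batch_size", 1: "sequence_length"}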