Skip to content

Commit 02c6ed5

Browse files
authored
Make stable diffusion unet and vae number of channels static (#1840)
1 parent b3ecb6c commit 02c6ed5

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

optimum/exporters/onnx/model_configs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,7 @@ class UNetOnnxConfig(VisionOnnxConfig):
981981
@property
982982
def inputs(self) -> Dict[str, Dict[int, str]]:
983983
common_inputs = {
984-
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
984+
"sample": {0: "batch_size", 2: "height", 3: "width"},
985985
"timestep": {0: "steps"},
986986
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
987987
}
@@ -998,7 +998,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
998998
@property
999999
def outputs(self) -> Dict[str, Dict[int, str]]:
10001000
return {
1001-
"out_sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
1001+
"out_sample": {0: "batch_size", 2: "height", 3: "width"},
10021002
}
10031003

10041004
@property
@@ -1045,13 +1045,13 @@ class VaeEncoderOnnxConfig(VisionOnnxConfig):
10451045
@property
10461046
def inputs(self) -> Dict[str, Dict[int, str]]:
10471047
return {
1048-
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
1048+
"sample": {0: "batch_size", 2: "height", 3: "width"},
10491049
}
10501050

10511051
@property
10521052
def outputs(self) -> Dict[str, Dict[int, str]]:
10531053
return {
1054-
"latent_sample": {0: "batch_size", 1: "num_channels_latent", 2: "height_latent", 3: "width_latent"},
1054+
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
10551055
}
10561056

10571057

@@ -1069,13 +1069,13 @@ class VaeDecoderOnnxConfig(VisionOnnxConfig):
10691069
@property
10701070
def inputs(self) -> Dict[str, Dict[int, str]]:
10711071
return {
1072-
"latent_sample": {0: "batch_size", 1: "num_channels_latent", 2: "height_latent", 3: "width_latent"},
1072+
"latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
10731073
}
10741074

10751075
@property
10761076
def outputs(self) -> Dict[str, Dict[int, str]]:
10771077
return {
1078-
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
1078+
"sample": {0: "batch_size", 2: "height", 3: "width"},
10791079
}
10801080

10811081

0 commit comments

Comments
 (0)