@@ -218,8 +218,8 @@ def test_shape(self, model_arch: str):
218
218
),
219
219
)
220
220
else :
221
- packed_height = height // pipeline .vae_scale_factor
222
- packed_width = width // pipeline .vae_scale_factor
221
+ packed_height = height // pipeline .vae_scale_factor // 2
222
+ packed_width = width // pipeline .vae_scale_factor // 2
223
223
channels = pipeline .transformer .config .in_channels
224
224
self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
225
225
@@ -426,7 +426,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
426
426
height = height , width = width , batch_size = batch_size , channel = channel , input_type = input_type
427
427
)
428
428
429
- if "flux" == model_type :
429
+ if model_type in [ "flux" , "stable-diffusion-3" ] :
430
430
inputs ["height" ] = height
431
431
inputs ["width" ] = width
432
432
@@ -529,8 +529,8 @@ def test_shape(self, model_arch: str):
529
529
),
530
530
)
531
531
else :
532
- packed_height = height // pipeline .vae_scale_factor
533
- packed_width = width // pipeline .vae_scale_factor
532
+ packed_height = height // pipeline .vae_scale_factor // 2
533
+ packed_width = width // pipeline .vae_scale_factor // 2
534
534
channels = pipeline .transformer .config .in_channels
535
535
self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
536
536
@@ -780,8 +780,8 @@ def test_shape(self, model_arch: str):
780
780
),
781
781
)
782
782
else :
783
- packed_height = height // pipeline .vae_scale_factor
784
- packed_width = width // pipeline .vae_scale_factor
783
+ packed_height = height // pipeline .vae_scale_factor // 2
784
+ packed_width = width // pipeline .vae_scale_factor // 2
785
785
channels = pipeline .transformer .config .in_channels
786
786
self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
787
787
0 commit comments