@@ -78,7 +78,7 @@ class OVPipelineForText2ImageTest(unittest.TestCase):
78
78
NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES = ["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ]
79
79
if is_transformers_version (">=" , "4.40.0" ):
80
80
SUPPORTED_ARCHITECTURES .extend (["stable-diffusion-3" , "flux" , "sana" ])
81
- NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES .append (["stable-diffusion-3" ])
81
+ NEGATIVE_PROMPT_SUPPORT_ARCHITECTURES .extend (["stable-diffusion-3" ])
82
82
CALLBACK_SUPPORT_ARCHITECTURES = ["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ]
83
83
84
84
OVMODEL_CLASS = OVPipelineForText2Image
@@ -94,13 +94,6 @@ def generate_inputs(self, height=128, width=128, batch_size=1):
94
94
95
95
return inputs
96
96
97
- def get_auto_cls (self , model_arch ):
98
- if model_arch == "sana" :
99
- from diffusers import SanaPipeline
100
-
101
- return SanaPipeline
102
- return self .AUTOMODEL_CLASS
103
-
104
97
@require_diffusers
105
98
def test_load_vanilla_model_which_is_not_supported (self ):
106
99
with self .assertRaises (Exception ) as context :
@@ -111,8 +104,7 @@ def test_load_vanilla_model_which_is_not_supported(self):
111
104
@parameterized .expand (SUPPORTED_ARCHITECTURES )
112
105
@require_diffusers
113
106
def test_ov_pipeline_class_dispatch (self , model_arch : str ):
114
- auto_cls = self .get_auto_cls (model_arch )
115
- auto_pipeline = DiffusionPipeline if model_arch != "sana" else auto_cls
107
+ auto_pipeline = DiffusionPipeline
116
108
auto_pipeline = auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
117
109
ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
118
110
@@ -141,21 +133,19 @@ def test_num_images_per_prompt(self, model_arch: str):
141
133
def test_compare_to_diffusers_pipeline (self , model_arch : str ):
142
134
height , width , batch_size = 64 , 64 , 1
143
135
inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
144
- auto_cls = self .get_auto_cls (model_arch )
145
136
ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
146
- diffusers_pipeline = auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
137
+ diffusers_pipeline = DiffusionPipeline .from_pretrained (MODEL_NAMES [model_arch ])
147
138
148
- with torch .no_grad ():
149
- for output_type in ["latent" , "np" , "pt" ]:
150
- inputs ["output_type" ] = output_type
151
- if model_arch == "sana" :
152
- # resolution binning will lead to resize output to standard resolution and back that can interpolate floating-point deviations
153
- inputs ["use_resolution_binning" ] = False
154
- atol = 1e-4
139
+ for output_type in ["latent" , "np" , "pt" ]:
140
+ inputs ["output_type" ] = output_type
141
+ if model_arch == "sana" :
142
+ # resolution binning will lead to resize output to standard resolution and back that can interpolate floating-point deviations
143
+ inputs ["use_resolution_binning" ] = False
144
+ atol = 1e-4
155
145
156
- ov_output = ov_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
157
- diffusers_output = diffusers_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
158
- np .testing .assert_allclose (ov_output , diffusers_output , atol = atol , rtol = 1e-2 )
146
+ ov_output = ov_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
147
+ diffusers_output = diffusers_pipeline (** inputs , generator = get_generator ("pt" , SEED )).images
148
+ np .testing .assert_allclose (ov_output , diffusers_output , atol = atol , rtol = 1e-2 )
159
149
160
150
# test on inputs nondivisible on 64
161
151
height , width , batch_size = 96 , 96 , 1
@@ -191,8 +181,7 @@ def __call__(self, *args, **kwargs) -> None:
191
181
auto_callback = Callback ()
192
182
193
183
ov_pipe = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
194
- auto_cls = self .get_auto_cls (model_arch )
195
- auto_pipe = auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
184
+ auto_pipe = DiffusionPipeline .from_pretrained (MODEL_NAMES [model_arch ])
196
185
197
186
# callback_steps=1 to trigger callback every step
198
187
ov_pipe (** inputs , callback = ov_callback , callback_steps = 1 )
0 commit comments