Skip to content

Commit 3cf8894

Browse files
committed
fix tests
1 parent 308c3ff commit 3cf8894

File tree

4 files changed

+65
-26
lines changed

4 files changed

+65
-26
lines changed

optimum/exporters/openvino/convert.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,11 @@ def get_diffusion_models_for_export_ext(
10081008
task="feature-extraction",
10091009
model_type="t5-encoder-model",
10101010
)
1011-
export_config = export_config_constructor(text_encoder_3.config, int_dtype=int_dtype, float_dtype=float_dtype, )
1011+
export_config = export_config_constructor(
1012+
text_encoder_3.config,
1013+
int_dtype=int_dtype,
1014+
float_dtype=float_dtype,
1015+
)
10121016
models_for_export["text_encoder_3"] = (text_encoder_3, export_config)
10131017

10141018
return None, models_for_export

optimum/exporters/openvino/model_configs.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from packaging import version
2121
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
2222
from transformers.utils import is_tf_available
23-
from optimum.exporters.onnx.model_patcher import ModelPatcher
2423

2524
from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
2625
from optimum.exporters.onnx.model_configs import (
@@ -38,6 +37,7 @@
3837
UNetOnnxConfig,
3938
VisionOnnxConfig,
4039
)
40+
from optimum.exporters.onnx.model_patcher import ModelPatcher
4141
from optimum.exporters.tasks import TasksManager
4242
from optimum.utils import DEFAULT_DUMMY_SHAPES
4343
from optimum.utils.input_generators import (
@@ -1583,4 +1583,3 @@ def patch_model_for_export(
15831583
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
15841584
) -> ModelPatcher:
15851585
return ModelPatcher(self, model, model_kwargs=model_kwargs)
1586-

optimum/intel/openvino/modeling_diffusion.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def width(self) -> int:
576576
@property
577577
def batch_size(self) -> int:
578578
model = self.unet.model if self.unet is not None else self.transformer.model
579-
batch_size = model.inputs[0].get_partial_shape()[0]
579+
batch_size = model.inputs[0].get_partial_shape()[0]
580580
if batch_size.is_dynamic:
581581
return -1
582582
return batch_size.get_length()
@@ -642,7 +642,7 @@ def _reshape_transformer(
642642
else:
643643
# The factor of 2 comes from the guidance scale > 1
644644
batch_size *= 2 * num_images_per_prompt
645-
645+
646646
height = height // self.vae_scale_factor if height > 0 else height
647647
width = width // self.vae_scale_factor if width > 0 else width
648648
shapes = {}
@@ -954,7 +954,9 @@ def modules(self):
954954
class OVModelTextEncoder(OVPipelinePart):
955955
def __init__(self, model: openvino.runtime.Model, parent_pipeline: OVDiffusionPipeline, model_name: str = ""):
956956
super().__init__(model, parent_pipeline, model_name)
957-
self.hidden_states_output_names = sorted({name for out in self.model.outputs for name in out.names if name.startswith("hidden_states")})
957+
self.hidden_states_output_names = sorted(
958+
{name for out in self.model.outputs for name in out.names if name.startswith("hidden_states")}
959+
)
958960

959961
def forward(
960962
self,
@@ -971,9 +973,13 @@ def forward(
971973
main_out = ov_outputs[0]
972974
model_outputs = {}
973975
model_outputs[self.model.outputs[0].get_any_name()] = torch.from_numpy(main_out)
974-
if self.hidden_states_output_names and not "last_hidden_state" in model_outputs:
976+
if self.hidden_states_output_names and "last_hidden_state" not in model_outputs:
975977
model_outputs["last_hidden_state"] = torch.from_numpy(ov_outputs[self.hidden_states_output_names[-1]])
976-
if self.hidden_states_output_names and output_hidden_states or self.config.output_hidden_states:
978+
if (
979+
self.hidden_states_output_names
980+
and output_hidden_states
981+
or getattr(self.config, "output_hidden_states", False)
982+
):
977983
hidden_states = [torch.from_numpy(ov_outputs[out_name]) for out_name in self.hidden_states_output_names]
978984
model_outputs["hidden_states"] = hidden_states
979985

tests/openvino/test_diffusion.py

+48-18
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,19 @@ def test_shape(self, model_arch: str):
185185
elif output_type == "pt":
186186
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
187187
else:
188-
out_channels = pipeline.unet.config.out_channels if pipeline.unet is not None else pipeline.transformer.config.out_channels
188+
out_channels = (
189+
pipeline.unet.config.out_channels
190+
if pipeline.unet is not None
191+
else pipeline.transformer.config.out_channels
192+
)
189193
self.assertEqual(
190194
outputs.shape,
191-
(batch_size, out_channels, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
195+
(
196+
batch_size,
197+
out_channels,
198+
height // pipeline.vae_scale_factor,
199+
width // pipeline.vae_scale_factor,
200+
),
192201
)
193202

194203
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -246,7 +255,7 @@ def test_negative_prompt(self, model_arch: str):
246255
do_classifier_free_guidance=True,
247256
negative_prompt=negative_prompt,
248257
)
249-
258+
250259
else:
251260
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = pipeline.encode_prompt(
252261
prompt=prompt,
@@ -306,8 +315,10 @@ def test_height_width_properties(self, model_arch: str):
306315
)
307316

308317
self.assertFalse(ov_pipeline.is_dynamic)
309-
expected_batch = batch_size * num_images_per_prompt
310-
if ov_pipeline.unet is not None and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}:
318+
expected_batch = batch_size * num_images_per_prompt
319+
if ov_pipeline.unet is not None and "timestep_cond" not in {
320+
inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
321+
}:
311322
expected_batch *= 2
312323
self.assertEqual(
313324
ov_pipeline.batch_size,
@@ -435,10 +446,19 @@ def test_shape(self, model_arch: str):
435446
elif output_type == "pt":
436447
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
437448
else:
438-
out_channels = pipeline.unet.config.out_channels if pipeline.unet is not None else pipeline.transformer.config.out_channels
449+
out_channels = (
450+
pipeline.unet.config.out_channels
451+
if pipeline.unet is not None
452+
else pipeline.transformer.config.out_channels
453+
)
439454
self.assertEqual(
440455
outputs.shape,
441-
(batch_size, out_channels, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
456+
(
457+
batch_size,
458+
out_channels,
459+
height // pipeline.vae_scale_factor,
460+
width // pipeline.vae_scale_factor,
461+
),
442462
)
443463

444464
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -521,13 +541,12 @@ def test_height_width_properties(self, model_arch: str):
521541
)
522542

523543
self.assertFalse(ov_pipeline.is_dynamic)
524-
expected_batch = batch_size * num_images_per_prompt
525-
if ov_pipeline.unet is not None and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}:
544+
expected_batch = batch_size * num_images_per_prompt
545+
if ov_pipeline.unet is not None and "timestep_cond" not in {
546+
inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
547+
}:
526548
expected_batch *= 2
527-
self.assertEqual(
528-
ov_pipeline.batch_size,
529-
expected_batch
530-
)
549+
self.assertEqual(ov_pipeline.batch_size, expected_batch)
531550
self.assertEqual(ov_pipeline.height, height)
532551
self.assertEqual(ov_pipeline.width, width)
533552

@@ -655,10 +674,19 @@ def test_shape(self, model_arch: str):
655674
elif output_type == "pt":
656675
self.assertEqual(outputs.shape, (batch_size, 3, height, width))
657676
else:
658-
out_channels = pipeline.unet.config.out_channels if pipeline.unet is not None else pipeline.transformer.config.out_channels
677+
out_channels = (
678+
pipeline.unet.config.out_channels
679+
if pipeline.unet is not None
680+
else pipeline.transformer.config.out_channels
681+
)
659682
self.assertEqual(
660683
outputs.shape,
661-
(batch_size, out_channels, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
684+
(
685+
batch_size,
686+
out_channels,
687+
height // pipeline.vae_scale_factor,
688+
width // pipeline.vae_scale_factor,
689+
),
662690
)
663691

664692
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -676,7 +704,7 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
676704

677705
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
678706
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
679-
707+
680708
np.testing.assert_allclose(ov_output, diffusers_output, atol=6e-3, rtol=1e-2)
681709

682710
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -741,8 +769,10 @@ def test_height_width_properties(self, model_arch: str):
741769
)
742770

743771
self.assertFalse(ov_pipeline.is_dynamic)
744-
expected_batch = batch_size * num_images_per_prompt
745-
if ov_pipeline.unet is not None and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs}:
772+
expected_batch = batch_size * num_images_per_prompt
773+
if ov_pipeline.unet is not None and "timestep_cond" not in {
774+
inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs
775+
}:
746776
expected_batch *= 2
747777
self.assertEqual(
748778
ov_pipeline.batch_size,

0 commit comments

Comments
 (0)