Skip to content

Commit 97308d2

Browse files
committed
fix sana
1 parent 047b43d commit 97308d2

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

optimum/intel/openvino/modeling_diffusion.py

+42-5
Original file line numberDiff line numberDiff line change
@@ -889,9 +889,7 @@ def reshape(
889889
)
890890

891891
if self.text_encoder_3 is not None:
892-
self.text_encoder_3.model = self._reshape_text_encoder(
893-
self.text_encoder_3.model, batch_size, getattr(self.tokenizer_3, "model_max_length", -1)
894-
)
892+
self.text_encoder_3.model = self._reshape_text_encoder(self.text_encoder_3.model, batch_size, -1)
895893

896894
self.clear_requests()
897895
return self
@@ -962,7 +960,7 @@ def components(self) -> Dict[str, Any]:
962960
components = {k: v for k, v in components.items() if v is not None}
963961
return components
964962

965-
def __call__(self, *args, height=None, width=None, **kwargs):
963+
def __call__(self, *args, **kwargs):
966964
# we do this to keep numpy random states support for now
967965
# TODO: deprecate and add warnings when a random state is passed
968966

@@ -973,23 +971,62 @@ def __call__(self, *args, height=None, width=None, **kwargs):
973971
for k, v in kwargs.items():
974972
kwargs[k] = np_to_pt_generators(v, self.device)
975973

974+
height, width = None, None
975+
height_idx, width_idx = None, None
976+
shapes_overriden = False
977+
sig = inspect.signature(self.auto_model_class.__call__)
978+
sig_height_idx = list(sig.parameters).index("height")
979+
sig_width_idx = list(sig.parameters).index("width")
980+
if "height" in kwargs:
981+
height = kwargs["height"]
982+
elif len(args) > sig_height_idx:
983+
height = args[sig_height_idx]
984+
height_idx = sig_height_idx
985+
986+
if "width" in kwargs:
987+
width = kwargs["width"]
988+
elif len(args) > sig_width_idx:
989+
width = args[sig_width_idx]
990+
width_idx = sig_width_idx
991+
976992
if self.height != -1:
977993
if height is not None and height != self.height:
978994
logger.warning(f"Incompatible height argument provided {height}. Pipeline only support {self.height}.")
979995
height = self.height
980996
else:
981997
height = self.height
982998

999+
if height_idx is not None:
1000+
args[height_idx] = height
1001+
else:
1002+
kwargs["height"] = height
1003+
1004+
shapes_overriden = True
1005+
9831006
if self.width != -1:
9841007
if width is not None and width != self.width:
9851008
logger.warning(f"Incompatible widtth argument provided {width}. Pipeline only support {self.width}.")
9861009
width = self.width
9871010
else:
9881011
width = self.width
9891012

1013+
if width_idx is not None:
1014+
args[width_idx] = width
1015+
else:
1016+
kwargs["width"] = width
1017+
shapes_overriden = True
1018+
1019+
# Sana generates images in specific resolution grid size and then resize to requested size by default, it may contradict with pipeline height / width
1020+
# Disable this behavior for static shape pipeline
1021+
if self.auto_model_class.__name__.startswith("Sana") and shapes_overriden:
1022+
sig_resolution_bining_idx = list(sig.parameters).index("use_resolution_binning")
1023+
if len(args) > sig_resolution_bining_idx:
1024+
args[sig_resolution_bining_idx] = False
1025+
else:
1026+
kwargs["use_resolution_binning"] = False
9901027
# we use auto_model_class.__call__ here because we can't call super().__call__
9911028
# as OptimizedModel already defines a __call__ which is the first in the MRO
992-
return self.auto_model_class.__call__(self, *args, height=height, width=width, **kwargs)
1029+
return self.auto_model_class.__call__(self, *args, **kwargs)
9931030

9941031

9951032
class OVPipelinePart(ConfigMixin):

tests/openvino/test_diffusion.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -442,18 +442,19 @@ def test_load_custom_weight_variant(self):
442442
@require_diffusers
443443
def test_static_shape_image_generation(self, model_arch):
444444
pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], compile=False)
445-
pipeline.reshape(batch_size=-1, height=40, width=32)
445+
pipeline.reshape(batch_size=1, height=64, width=32)
446446
pipeline.compile()
447447
# generation with incompatible size
448448
height, width, batch_size = 64, 64, 1
449449
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
450-
image = pipeline(**inputs, num_inference_steps=2).images[0]
451-
self.assertTupleEqual(image.size, (32, 40))
450+
inputs["output_type"] = "pil"
451+
image = pipeline(**inputs).images[0]
452+
self.assertTupleEqual(image.size, (32, 64))
452453
# generation without height / width provided
453454
inputs.pop("height")
454455
inputs.pop("width")
455-
image = pipeline(**inputs, num_inference_steps=2).images[0]
456-
self.assertTupleEqual(image.size, (32, 40))
456+
image = pipeline(**inputs).images[0]
457+
self.assertTupleEqual(image.size, (32, 64))
457458

458459

459460
class OVPipelineForImage2ImageTest(unittest.TestCase):

0 commit comments

Comments
 (0)