@@ -889,9 +889,7 @@ def reshape(
889
889
)
890
890
891
891
if self .text_encoder_3 is not None :
892
- self .text_encoder_3 .model = self ._reshape_text_encoder (
893
- self .text_encoder_3 .model , batch_size , getattr (self .tokenizer_3 , "model_max_length" , - 1 )
894
- )
892
+ self .text_encoder_3 .model = self ._reshape_text_encoder (self .text_encoder_3 .model , batch_size , - 1 )
895
893
896
894
self .clear_requests ()
897
895
return self
@@ -962,7 +960,7 @@ def components(self) -> Dict[str, Any]:
962
960
components = {k : v for k , v in components .items () if v is not None }
963
961
return components
964
962
965
- def __call__ (self , * args , height = None , width = None , ** kwargs ):
963
+ def __call__ (self , * args , ** kwargs ):
966
964
# we do this to keep numpy random states support for now
967
965
# TODO: deprecate and add warnings when a random state is passed
968
966
@@ -973,23 +971,62 @@ def __call__(self, *args, height=None, width=None, **kwargs):
973
971
for k , v in kwargs .items ():
974
972
kwargs [k ] = np_to_pt_generators (v , self .device )
975
973
974
+ height , width = None , None
975
+ height_idx , width_idx = None , None
976
+ shapes_overriden = False
977
+ sig = inspect .signature (self .auto_model_class .__call__ )
978
+ sig_height_idx = list (sig .parameters ).index ("height" )
979
+ sig_width_idx = list (sig .parameters ).index ("width" )
980
+ if "height" in kwargs :
981
+ height = kwargs ["height" ]
982
+ elif len (args ) > sig_height_idx :
983
+ height = args [sig_height_idx ]
984
+ height_idx = sig_height_idx
985
+
986
+ if "width" in kwargs :
987
+ width = kwargs ["width" ]
988
+ elif len (args ) > sig_width_idx :
989
+ width = args [sig_width_idx ]
990
+ width_idx = sig_width_idx
991
+
976
992
if self .height != - 1 :
977
993
if height is not None and height != self .height :
978
994
logger .warning (f"Incompatible height argument provided { height } . Pipeline only support { self .height } ." )
979
995
height = self .height
980
996
else :
981
997
height = self .height
982
998
999
+ if height_idx is not None :
1000
+ args [height_idx ] = height
1001
+ else :
1002
+ kwargs ["height" ] = height
1003
+
1004
+ shapes_overriden = True
1005
+
983
1006
if self .width != - 1 :
984
1007
if width is not None and width != self .width :
985
1008
logger .warning (f"Incompatible widtth argument provided { width } . Pipeline only support { self .width } ." )
986
1009
width = self .width
987
1010
else :
988
1011
width = self .width
989
1012
1013
+ if width_idx is not None :
1014
+ args [width_idx ] = width
1015
+ else :
1016
+ kwargs ["width" ] = width
1017
+ shapes_overriden = True
1018
+
1019
+ # Sana generates images in specific resolution grid size and then resize to requested size by default, it may contradict with pipeline height / width
1020
+ # Disable this behavior for static shape pipeline
1021
+ if self .auto_model_class .__name__ .startswith ("Sana" ) and shapes_overriden :
1022
+ sig_resolution_bining_idx = list (sig .parameters ).index ("use_resolution_binning" )
1023
+ if len (args ) > sig_resolution_bining_idx :
1024
+ args [sig_resolution_bining_idx ] = False
1025
+ else :
1026
+ kwargs ["use_resolution_binning" ] = False
990
1027
# we use auto_model_class.__call__ here because we can't call super().__call__
991
1028
# as OptimizedModel already defines a __call__ which is the first in the MRO
992
- return self .auto_model_class .__call__ (self , * args , height = height , width = width , ** kwargs )
1029
+ return self .auto_model_class .__call__ (self , * args , ** kwargs )
993
1030
994
1031
995
1032
class OVPipelinePart (ConfigMixin ):
0 commit comments