@@ -889,9 +889,7 @@ def reshape(
889
889
)
890
890
891
891
if self .text_encoder_3 is not None :
892
- self .text_encoder_3 .model = self ._reshape_text_encoder (
893
- self .text_encoder_3 .model , batch_size , getattr (self .tokenizer_3 , "model_max_length" , - 1 )
894
- )
892
+ self .text_encoder_3 .model = self ._reshape_text_encoder (self .text_encoder_3 .model , batch_size , - 1 )
895
893
896
894
self .clear_requests ()
897
895
return self
@@ -973,6 +971,63 @@ def __call__(self, *args, **kwargs):
973
971
for k , v in kwargs .items ():
974
972
kwargs [k ] = np_to_pt_generators (v , self .device )
975
973
974
+ height , width = None , None
975
+ height_idx , width_idx = None , None
976
+ shapes_overriden = False
977
+ sig = inspect .signature (self .auto_model_class .__call__ )
978
+ sig_height_idx = list (sig .parameters ).index ("height" ) if "height" in sig .parameters else len (sig .parameters )
979
+ sig_width_idx = list (sig .parameters ).index ("width" ) if "width" in sig .parameters else len (sig .parameters )
980
+ if "height" in kwargs :
981
+ height = kwargs ["height" ]
982
+ elif len (args ) > sig_height_idx :
983
+ height = args [sig_height_idx ]
984
+ height_idx = sig_height_idx
985
+
986
+ if "width" in kwargs :
987
+ width = kwargs ["width" ]
988
+ elif len (args ) > sig_width_idx :
989
+ width = args [sig_width_idx ]
990
+ width_idx = sig_width_idx
991
+
992
+ if self .height != - 1 :
993
+ if height is not None and height != self .height :
994
+ logger .warning (f"Incompatible height argument provided { height } . Pipeline only support { self .height } ." )
995
+ height = self .height
996
+ else :
997
+ height = self .height
998
+
999
+ if height_idx is not None :
1000
+ args [height_idx ] = height
1001
+ else :
1002
+ kwargs ["height" ] = height
1003
+
1004
+ shapes_overriden = True
1005
+
1006
+ if self .width != - 1 :
1007
+ if width is not None and width != self .width :
1008
+ logger .warning (f"Incompatible widtth argument provided { width } . Pipeline only support { self .width } ." )
1009
+ width = self .width
1010
+ else :
1011
+ width = self .width
1012
+
1013
+ if width_idx is not None :
1014
+ args [width_idx ] = width
1015
+ else :
1016
+ kwargs ["width" ] = width
1017
+ shapes_overriden = True
1018
+
1019
+ # Sana generates images in specific resolution grid size and then resize to requested size by default, it may contradict with pipeline height / width
1020
+ # Disable this behavior for static shape pipeline
1021
+ if self .auto_model_class .__name__ .startswith ("Sana" ) and shapes_overriden :
1022
+ sig_resolution_bining_idx = (
1023
+ list (sig .parameters ).index ("use_resolution_binning" )
1024
+ if "use_resolution_binning" in sig .parameters
1025
+ else len (sig .parameters )
1026
+ )
1027
+ if len (args ) > sig_resolution_bining_idx :
1028
+ args [sig_resolution_bining_idx ] = False
1029
+ else :
1030
+ kwargs ["use_resolution_binning" ] = False
976
1031
# we use auto_model_class.__call__ here because we can't call super().__call__
977
1032
# as OptimizedModel already defines a __call__ which is the first in the MRO
978
1033
return self .auto_model_class .__call__ (self , * args , ** kwargs )
0 commit comments