8
8
import comfy .model_management
9
9
import comfy .model_detection
10
10
import comfy .controlnet as comfy_cn
11
- from comfy .controlnet import ControlBase , ControlNet , ControlLora , T2IAdapter
11
+ from comfy .controlnet import ControlBase , ControlNet , ControlLora , T2IAdapter , StrengthType
12
12
from comfy .model_patcher import ModelPatcher
13
13
14
14
from .control_sparsectrl import SparseModelPatcher , SparseControlNet , SparseCtrlMotionWrapper , SparseSettings , SparseConst
21
21
22
22
23
23
class ControlNetAdvanced (ControlNet , AdvancedControlBase ):
24
- def __init__ (self , control_model , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None ):
24
+ def __init__ (self , control_model , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None , extra_conds = [ "y" ], strength_type = StrengthType . CONSTANT ):
25
25
super ().__init__ (control_model = control_model , global_average_pooling = global_average_pooling , compression_ratio = compression_ratio , latent_format = latent_format , device = device , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
26
26
AdvancedControlBase .__init__ (self , super (), timestep_keyframes = timestep_keyframes , weights_default = ControlWeights .controlnet ())
27
+ self .is_flux = False
28
+ self .x_noisy_shape = None
27
29
28
30
def get_universal_weights (self ) -> ControlWeights :
29
- raw_weights = [(self .weights .base_multiplier ** float (12 - i )) for i in range (13 )]
30
- return self .weights .copy_with_new_weights (raw_weights )
31
+ def cn_weights_func (idx : int , control : dict [str , list [Tensor ]], key : str ):
32
+ if key == "middle" :
33
+ return 1.0
34
+ c_len = len (control [key ])
35
+ raw_weights = [(self .weights .base_multiplier ** float ((c_len ) - i )) for i in range (c_len + 1 )]
36
+ raw_weights = raw_weights [:- 1 ]
37
+ if key == "input" :
38
+ raw_weights .reverse ()
39
+ return raw_weights [idx ]
40
+ return self .weights .copy_with_new_weights (new_weight_func = cn_weights_func )
31
41
32
42
def get_control_advanced (self , x_noisy , t , cond , batched_number ):
33
43
# perform special version of get_control that supports sliding context and masks
@@ -49,7 +59,6 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
49
59
if self .manual_cast_dtype is not None :
50
60
dtype = self .manual_cast_dtype
51
61
52
- output_dtype = x_noisy .dtype
53
62
# make cond_hint appropriate dimensions
54
63
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
55
64
if self .sub_idxs is not None or self .cond_hint is None or x_noisy .shape [2 ] * self .compression_ratio != self .cond_hint .shape [2 ] or x_noisy .shape [3 ] * self .compression_ratio != self .cond_hint .shape [3 ]:
@@ -64,9 +73,9 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
64
73
actual_cond_hint_orig = self .cond_hint_original
65
74
if self .cond_hint_original .size (0 ) < self .full_latent_length :
66
75
actual_cond_hint_orig = extend_to_batch_size (tensor = actual_cond_hint_orig , batch_size = self .full_latent_length )
67
- self .cond_hint = comfy .utils .common_upscale (actual_cond_hint_orig [self .sub_idxs ], x_noisy .shape [3 ] * compression_ratio , x_noisy .shape [2 ] * compression_ratio , 'nearest-exact' , "center" )
76
+ self .cond_hint = comfy .utils .common_upscale (actual_cond_hint_orig [self .sub_idxs ], x_noisy .shape [3 ] * compression_ratio , x_noisy .shape [2 ] * compression_ratio , self . upscale_algorithm , "center" )
68
77
else :
69
- self .cond_hint = comfy .utils .common_upscale (self .cond_hint_original , x_noisy .shape [3 ] * compression_ratio , x_noisy .shape [2 ] * compression_ratio , 'nearest-exact' , "center" )
78
+ self .cond_hint = comfy .utils .common_upscale (self .cond_hint_original , x_noisy .shape [3 ] * compression_ratio , x_noisy .shape [2 ] * compression_ratio , self . upscale_algorithm , "center" )
70
79
if self .vae is not None :
71
80
loaded_models = comfy .model_management .loaded_models (only_currently_used = True )
72
81
self .cond_hint = self .vae .encode (self .cond_hint .movedim (1 , - 1 ))
@@ -81,25 +90,44 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
81
90
self .prepare_mask_cond_hint (x_noisy = x_noisy , t = t , cond = cond , batched_number = batched_number , dtype = dtype )
82
91
83
92
context = cond .get ('crossattn_controlnet' , cond ['c_crossattn' ])
84
- y = cond .get ('y' , None )
85
- if y is not None :
86
- y = y .to (dtype )
93
+ extra = self .extra_args .copy ()
94
+ for c in self .extra_conds :
95
+ temp = cond .get (c , None )
96
+ if temp is not None :
97
+ extra [c ] = temp .to (dtype )
98
+
87
99
timestep = self .model_sampling_current .timestep (t )
88
100
x_noisy = self .model_sampling_current .calculate_input (t , x_noisy )
101
+ self .x_noisy_shape = x_noisy .shape
102
+ control = self .control_model (x = x_noisy .to (dtype ), hint = self .cond_hint , timesteps = timestep .to (dtype ), context = context .to (dtype ), ** extra )
103
+ return self .control_merge (control , control_prev , output_dtype = None )
89
104
90
- control = self .control_model (x = x_noisy .to (dtype ), hint = self .cond_hint , timesteps = timestep .float (), context = context .to (dtype ), y = y )
91
- return self .control_merge (control , control_prev , output_dtype )
105
+ def pre_run_advanced (self , * args , ** kwargs ):
106
+ self .is_flux = "Flux" in str (type (self .control_model ).__name__ )
107
+ return super ().pre_run_advanced (* args , ** kwargs )
108
+
109
+ def apply_advanced_strengths_and_masks (self , x : Tensor , batched_number : int , flux_shape = None ):
110
+ if self .is_flux :
111
+ flux_shape = self .x_noisy_shape
112
+ return super ().apply_advanced_strengths_and_masks (x , batched_number , flux_shape )
92
113
93
114
def copy (self ):
94
115
c = ControlNetAdvanced (self .control_model , self .timestep_keyframes , global_average_pooling = self .global_average_pooling , load_device = self .load_device , manual_cast_dtype = self .manual_cast_dtype )
116
+ c .control_model = self .control_model
117
+ c .control_model_wrapped = self .control_model_wrapped
95
118
self .copy_to (c )
96
119
self .copy_to_advanced (c )
97
120
return c
98
121
122
+ def cleanup_advanced (self ):
123
+ self .x_noisy_shape = None
124
+ return super ().cleanup_advanced ()
125
+
99
126
@staticmethod
100
127
def from_vanilla (v : ControlNet , timestep_keyframe : TimestepKeyframeGroup = None ) -> 'ControlNetAdvanced' :
101
128
to_return = ControlNetAdvanced (control_model = v .control_model , timestep_keyframes = timestep_keyframe ,
102
- global_average_pooling = v .global_average_pooling , compression_ratio = v .compression_ratio , latent_format = v .latent_format , device = v .device , load_device = v .load_device , manual_cast_dtype = v .manual_cast_dtype )
129
+ global_average_pooling = v .global_average_pooling , compression_ratio = v .compression_ratio , latent_format = v .latent_format , device = v .device , load_device = v .load_device ,
130
+ manual_cast_dtype = v .manual_cast_dtype )
103
131
v .copy_to (to_return )
104
132
return to_return
105
133
@@ -121,18 +149,28 @@ def control_merge_inject(self, control: dict[str, list[Tensor]], control_prev, o
121
149
return AdvancedControlBase .control_merge_inject (self , control , control_prev , output_dtype )
122
150
123
151
def get_universal_weights (self ) -> ControlWeights :
124
- raw_weights = [(self .weights .base_multiplier ** float (7 - i )) for i in range (8 )]
125
- raw_weights = [raw_weights [- 8 ], raw_weights [- 3 ], raw_weights [- 2 ], raw_weights [- 1 ]]
126
- raw_weights = get_properly_arranged_t2i_weights (raw_weights )
127
- raw_weights .reverse () # need to reverse to match recent ComfyUI changes
128
- return self .weights .copy_with_new_weights (raw_weights )
152
+ def t2i_weights_func (idx : int , control : dict [str , list [Tensor ]], key : str ):
153
+ if key == "middle" :
154
+ return 1.0
155
+ c_len = 8 #len(control[key])
156
+ raw_weights = [(self .weights .base_multiplier ** float ((c_len - 1 ) - i )) for i in range (c_len )]
157
+ raw_weights = [raw_weights [- c_len ], raw_weights [- 3 ], raw_weights [- 2 ], raw_weights [- 1 ]]
158
+ raw_weights = get_properly_arranged_t2i_weights (raw_weights )
159
+ if key == "input" :
160
+ raw_weights .reverse ()
161
+ return raw_weights [idx ]
162
+ return self .weights .copy_with_new_weights (new_weight_func = t2i_weights_func )
129
163
130
164
def get_calc_pow (self , idx : int , control : dict [str , list [Tensor ]], key : str ) -> int :
165
+ if key == "middle" :
166
+ return 0
131
167
# match how T2IAdapterAdvanced deals with universal weights
132
- indeces = [7 - i for i in range (8 )]
133
- indeces = [indeces [- 8 ], indeces [- 3 ], indeces [- 2 ], indeces [- 1 ]]
168
+ c_len = 8 #len(control[key])
169
+ indeces = [(c_len - 1 ) - i for i in range (c_len )]
170
+ indeces = [indeces [- c_len ], indeces [- 3 ], indeces [- 2 ], indeces [- 1 ]]
134
171
indeces = get_properly_arranged_t2i_weights (indeces )
135
- indeces .reverse () # need to reverse to match recent ComfyUI changes
172
+ if key == "input" :
173
+ indeces .reverse () # need to reverse to match recent ComfyUI changes
136
174
return indeces [idx ]
137
175
138
176
def get_control_advanced (self , x_noisy , t , cond , batched_number ):
@@ -381,11 +419,11 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
381
419
control = self .control_model (x = x_noisy .to (dtype ), hint = self .cond_hint , timesteps = timestep .float (), context = context .to (dtype ), y = y )
382
420
return self .control_merge (control , control_prev , output_dtype )
383
421
384
- def apply_advanced_strengths_and_masks (self , x : Tensor , batched_number : int ):
422
+ def apply_advanced_strengths_and_masks (self , x : Tensor , batched_number : int , * args , ** kwargs ):
385
423
# apply mults to indexes with and without a direct condhint
386
424
x [self .local_sparse_idxs ] *= self .sparse_settings .sparse_hint_mult * self .weights .extras .get (SparseConst .HINT_MULT , 1.0 )
387
425
x [self .local_sparse_idxs_inverse ] *= self .sparse_settings .sparse_nonhint_mult * self .weights .extras .get (SparseConst .NONHINT_MULT , 1.0 )
388
- return super ().apply_advanced_strengths_and_masks (x , batched_number )
426
+ return super ().apply_advanced_strengths_and_masks (x , batched_number , * args , ** kwargs )
389
427
390
428
def pre_run_advanced (self , model , percent_to_timestep_function ):
391
429
super ().pre_run_advanced (model , percent_to_timestep_function )
0 commit comments