Skip to content

Commit dcc928b

Browse files
authored
Merge PR #169 from Kosinkadink/develop - initial flux support
Initial flux support, refactoring weight control
2 parents 949843e + ef16e3c commit dcc928b

7 files changed

+552
-218
lines changed

adv_control/control.py

+61-23
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import comfy.model_management
99
import comfy.model_detection
1010
import comfy.controlnet as comfy_cn
11-
from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter
11+
from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, StrengthType
1212
from comfy.model_patcher import ModelPatcher
1313

1414
from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst
@@ -21,13 +21,23 @@
2121

2222

2323
class ControlNetAdvanced(ControlNet, AdvancedControlBase):
24-
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
24+
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
2525
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
2626
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
27+
self.is_flux = False
28+
self.x_noisy_shape = None
2729

2830
def get_universal_weights(self) -> ControlWeights:
29-
raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
30-
return self.weights.copy_with_new_weights(raw_weights)
31+
def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
32+
if key == "middle":
33+
return 1.0
34+
c_len = len(control[key])
35+
raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
36+
raw_weights = raw_weights[:-1]
37+
if key == "input":
38+
raw_weights.reverse()
39+
return raw_weights[idx]
40+
return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
3141

3242
def get_control_advanced(self, x_noisy, t, cond, batched_number):
3343
# perform special version of get_control that supports sliding context and masks
@@ -49,7 +59,6 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
4959
if self.manual_cast_dtype is not None:
5060
dtype = self.manual_cast_dtype
5161

52-
output_dtype = x_noisy.dtype
5362
# make cond_hint appropriate dimensions
5463
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
5564
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
@@ -64,9 +73,9 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
6473
actual_cond_hint_orig = self.cond_hint_original
6574
if self.cond_hint_original.size(0) < self.full_latent_length:
6675
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
67-
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
76+
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
6877
else:
69-
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
78+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
7079
if self.vae is not None:
7180
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
7281
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -81,25 +90,44 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
8190
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
8291

8392
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
84-
y = cond.get('y', None)
85-
if y is not None:
86-
y = y.to(dtype)
93+
extra = self.extra_args.copy()
94+
for c in self.extra_conds:
95+
temp = cond.get(c, None)
96+
if temp is not None:
97+
extra[c] = temp.to(dtype)
98+
8799
timestep = self.model_sampling_current.timestep(t)
88100
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
101+
self.x_noisy_shape = x_noisy.shape
102+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
103+
return self.control_merge(control, control_prev, output_dtype=None)
89104

90-
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
91-
return self.control_merge(control, control_prev, output_dtype)
105+
def pre_run_advanced(self, *args, **kwargs):
106+
self.is_flux = "Flux" in str(type(self.control_model).__name__)
107+
return super().pre_run_advanced(*args, **kwargs)
108+
109+
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape=None):
110+
if self.is_flux:
111+
flux_shape = self.x_noisy_shape
112+
return super().apply_advanced_strengths_and_masks(x, batched_number, flux_shape)
92113

93114
def copy(self):
94115
c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
116+
c.control_model = self.control_model
117+
c.control_model_wrapped = self.control_model_wrapped
95118
self.copy_to(c)
96119
self.copy_to_advanced(c)
97120
return c
98121

122+
def cleanup_advanced(self):
123+
self.x_noisy_shape = None
124+
return super().cleanup_advanced()
125+
99126
@staticmethod
100127
def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
101128
to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
102-
global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
129+
global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device,
130+
manual_cast_dtype=v.manual_cast_dtype)
103131
v.copy_to(to_return)
104132
return to_return
105133

@@ -121,18 +149,28 @@ def control_merge_inject(self, control: dict[str, list[Tensor]], control_prev, o
121149
return AdvancedControlBase.control_merge_inject(self, control, control_prev, output_dtype)
122150

123151
def get_universal_weights(self) -> ControlWeights:
124-
raw_weights = [(self.weights.base_multiplier ** float(7 - i)) for i in range(8)]
125-
raw_weights = [raw_weights[-8], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
126-
raw_weights = get_properly_arranged_t2i_weights(raw_weights)
127-
raw_weights.reverse() # need to reverse to match recent ComfyUI changes
128-
return self.weights.copy_with_new_weights(raw_weights)
152+
def t2i_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
153+
if key == "middle":
154+
return 1.0
155+
c_len = 8 #len(control[key])
156+
raw_weights = [(self.weights.base_multiplier ** float((c_len-1) - i)) for i in range(c_len)]
157+
raw_weights = [raw_weights[-c_len], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
158+
raw_weights = get_properly_arranged_t2i_weights(raw_weights)
159+
if key == "input":
160+
raw_weights.reverse()
161+
return raw_weights[idx]
162+
return self.weights.copy_with_new_weights(new_weight_func=t2i_weights_func)
129163

130164
def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
165+
if key == "middle":
166+
return 0
131167
# match how T2IAdapterAdvanced deals with universal weights
132-
indeces = [7 - i for i in range(8)]
133-
indeces = [indeces[-8], indeces[-3], indeces[-2], indeces[-1]]
168+
c_len = 8 #len(control[key])
169+
indeces = [(c_len-1) - i for i in range(c_len)]
170+
indeces = [indeces[-c_len], indeces[-3], indeces[-2], indeces[-1]]
134171
indeces = get_properly_arranged_t2i_weights(indeces)
135-
indeces.reverse() # need to reverse to match recent ComfyUI changes
172+
if key == "input":
173+
indeces.reverse() # need to reverse to match recent ComfyUI changes
136174
return indeces[idx]
137175

138176
def get_control_advanced(self, x_noisy, t, cond, batched_number):
@@ -381,11 +419,11 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
381419
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
382420
return self.control_merge(control, control_prev, output_dtype)
383421

384-
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
422+
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, *args, **kwargs):
385423
# apply mults to indexes with and without a direct condhint
386424
x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0)
387425
x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0)
388-
return super().apply_advanced_strengths_and_masks(x, batched_number)
426+
return super().apply_advanced_strengths_and_masks(x, batched_number, *args, **kwargs)
389427

390428
def pre_run_advanced(self, model, percent_to_timestep_function):
391429
super().pre_run_advanced(model, percent_to_timestep_function)

adv_control/control_plusplus.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,16 @@ def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: Timest
237237
self.single_control_type: str = None
238238

239239
def get_universal_weights(self) -> ControlWeights:
240-
# TODO: match actual layer count of model
241-
raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
242-
return self.weights.copy_with_new_weights(raw_weights)
240+
def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
241+
if key == "middle":
242+
return 1.0
243+
c_len = len(control[key])
244+
raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
245+
raw_weights = raw_weights[:-1]
246+
if key == "input":
247+
raw_weights.reverse()
248+
return raw_weights[idx]
249+
return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
243250

244251
def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
245252
if pp_group is not None:

0 commit comments

Comments
 (0)