Commit 7a456aa

Merge pull request #123 from Kosinkadink/sd3-changes
SD3 ControlNet support + New ComfyUI Compatibility
2 parents bf16347 + 3d00251 commit 7a456aa

8 files changed: +168 -136 lines

adv_control/control.py (+57 -45)

Large diffs are not rendered by default.

adv_control/control_reference.py (+10 -8)

@@ -218,7 +218,7 @@ def create_combo(reference_type: str, style_fidelity: float, ref_weight: float,


 class ReferencePreprocWrapper(AbstractPreprocWrapper):
-    error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
+    error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
     def __init__(self, condhint: Tensor):
         super().__init__(condhint)

@@ -228,10 +228,12 @@ class ReferenceAdvanced(ControlBase, AdvancedControlBase):

     def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup, device=None):
         super().__init__(device)
-        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
+        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
+        # TODO: allow vae_optional to be used instead of preprocessor
+        #require_vae=True
         self.ref_opts = ref_opts
         self.order = 0
-        self.latent_format = None
+        self.model_latent_format = None
         self.model_sampling_current = None
         self.should_apply_attn_effective_strength = False
         self.should_apply_adain_effective_strength = False

@@ -288,9 +290,9 @@ def should_run(self):

     def pre_run_advanced(self, model, percent_to_timestep_function):
         AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
-        if type(self.cond_hint_original) == ReferencePreprocWrapper:
+        if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
             self.cond_hint_original = self.cond_hint_original.condhint
-        self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
+        self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
         self.model_sampling_current = model.model_sampling
         # SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments
         if type(model).__name__ == "SDXL":

@@ -328,7 +330,7 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
         # noise cond_hint based on sigma (current step)
-        self.cond_hint = self.latent_format.process_in(self.cond_hint)
+        self.cond_hint = self.model_latent_format.process_in(self.cond_hint)
         self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
         timestep = self.model_sampling_current.timestep(t)
         self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))

@@ -343,8 +345,8 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):

     def cleanup_advanced(self):
         super().cleanup_advanced()
-        del self.latent_format
-        self.latent_format = None
+        del self.model_latent_format
+        self.model_latent_format = None
         del self.model_sampling_current
         self.model_sampling_current = None
         self.should_apply_attn_effective_strength = False
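
A note on the isinstance() change above: switching from a type() equality check to isinstance(self.cond_hint_original, AbstractPreprocWrapper) lets this unwrapping path accept any preprocessor wrapper built on the new shared base class, not just the Reference one. A minimal sketch of the difference (class bodies simplified for illustration; the real wrappers live in this repo):

    class AbstractPreprocWrapper:
        def __init__(self, condhint):
            self.condhint = condhint

    class ReferencePreprocWrapper(AbstractPreprocWrapper): ...
    class PreprocSparseRGBWrapper(AbstractPreprocWrapper): ...

    hint = PreprocSparseRGBWrapper(condhint=None)
    print(type(hint) == ReferencePreprocWrapper)     # False: exact-type check misses sibling wrappers
    print(isinstance(hint, AbstractPreprocWrapper))  # True: any subclass of the shared base is unwrapped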

adv_control/control_sparsectrl.py (+10 -35)

@@ -23,13 +23,13 @@
 from comfy.ldm.modules.attention import SpatialTransformer
 from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default
 from comfy.ldm.modules.attention import FeedForward, SpatialTransformer
-from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample
+from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
 from comfy.model_patcher import ModelPatcher
 import comfy.ops
 import comfy.model_management

 from .logger import logger
-from .utils import (BIGMAX, TimestepKeyframeGroup, disable_weight_init_clean_groupnorm,
+from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm,
                     prepare_mask_batch, broadcast_image_to_extend, extend_to_batch_size)

@@ -85,7 +85,8 @@ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
         x = torch.zeros_like(x)
         guided_hint = self.input_hint_block(hint, emb, context)

-        outs = []
+        out_output = []
+        out_middle = []

         hs = []
         if self.num_classes is not None:

@@ -100,12 +101,12 @@ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
                 guided_hint = None
             else:
                 h = module(h, emb, context)
-            outs.append(zero_conv(h, emb, context))
+            out_output.append(zero_conv(h, emb, context))

         h = self.middle_block(h, emb, context)
-        outs.append(self.middle_block_out(h, emb, context))
+        out_middle.append(self.middle_block_out(h, emb, context))

-        return outs
+        return {"middle": out_middle, "output": out_output}


 class SparseModelPatcher(ModelPatcher):

@@ -154,36 +155,10 @@ def clone(self):
         self.object_patches_backup = n.object_patches_backup


-class PreprocSparseRGBWrapper:
-    error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
+class PreprocSparseRGBWrapper(AbstractPreprocWrapper):
+    error_msg = error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
     def __init__(self, condhint: Tensor):
-        self.condhint = condhint
-
-    def movedim(self, *args, **kwargs):
-        return self
-
-    def __getattr__(self, *args, **kwargs):
-        raise AttributeError(self.error_msg)
-
-    def __setattr__(self, name, value):
-        if name != "condhint":
-            raise AttributeError(self.error_msg)
-        super().__setattr__(name, value)
-
-    def __iter__(self, *args, **kwargs):
-        raise AttributeError(self.error_msg)
-
-    def __next__(self, *args, **kwargs):
-        raise AttributeError(self.error_msg)
-
-    def __len__(self, *args, **kwargs):
-        raise AttributeError(self.error_msg)
-
-    def __getitem__(self, *args, **kwargs):
-        raise AttributeError(self.error_msg)
-
-    def __setitem__(self, *args, **kwargs):
-        raise AttributeError(self.error_msg)
+        super().__init__(condhint)


 class SparseContextAware:
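
The roughly thirty lines deleted from PreprocSparseRGBWrapper are not lost: the commit moves that guard behavior into a shared AbstractPreprocWrapper base class in adv_control/utils.py, which ReferencePreprocWrapper (above) now reuses too. A plausible reconstruction of that base class, inferred directly from the methods removed here (the actual utils.py implementation may differ in details):

    from torch import Tensor

    class AbstractPreprocWrapper:
        # subclasses override error_msg with a hint-specific explanation
        error_msg = "Invalid use of preprocessor output."

        def __init__(self, condhint: Tensor):
            self.condhint = condhint

        def movedim(self, *args, **kwargs):
            # survive ComfyUI's image dim shuffling by returning the wrapper untouched
            return self

        def __getattr__(self, *args, **kwargs):
            raise AttributeError(self.error_msg)

        def __setattr__(self, name, value):
            # only the wrapped latent may be assigned; anything else is misuse
            if name != "condhint":
                raise AttributeError(self.error_msg)
            super().__setattr__(name, value)

        def __iter__(self, *args, **kwargs):
            raise AttributeError(self.error_msg)

        def __next__(self, *args, **kwargs):
            raise AttributeError(self.error_msg)

        def __len__(self, *args, **kwargs):
            raise AttributeError(self.error_msg)

        def __getitem__(self, *args, **kwargs):
            raise AttributeError(self.error_msg)

        def __setitem__(self, *args, **kwargs):
            raise AttributeError(self.error_msg)

The wrapper deliberately pretends to be an IMAGE only as far as movedim(); every other tensor-like operation raises error_msg, so a misconnected workflow gets a readable error instead of a shape mismatch.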

adv_control/control_svd.py (+5 -4)

@@ -311,7 +311,8 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):

         guided_hint = self.input_hint_block(hint, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)

-        outs = []
+        out_output = []
+        out_middle = []

         hs = []
         if self.num_classes is not None:

@@ -326,12 +327,12 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
                 guided_hint = None
             else:
                 h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-            outs.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
+            out_output.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))

         h = self.middle_block(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-        outs.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
+        out_middle.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))

-        return outs
+        return {"middle": out_middle, "output": out_output}


 TEMPORAL_TRANSFORMER_BLOCKS = {
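
With this change, both the SparseCtrl and SVD control models return {"middle": [...], "output": [...]} instead of one flat list, matching the keyed shape newer ComfyUI expects when merging ControlNet residuals into the UNet; this is the "New ComfyUI Compatibility" half of the PR title. A self-contained sketch of how such a dict is consumed, approximating the merge logic rather than quoting ComfyUI's internals:

    import torch

    def apply_control(h, control, name):
        # pop the next residual for this block type ("middle" or "output")
        # and add it onto the UNet hidden state, skipping explicit None entries
        if control is not None and name in control and len(control[name]) > 0:
            ctrl = control[name].pop()
            if ctrl is not None:
                h = h + ctrl
        return h

    # toy residuals standing in for a control model's forward() result
    control = {"middle": [torch.zeros(1, 4)],
               "output": [torch.zeros(1, 4), torch.zeros(1, 4)]}
    h = torch.randn(1, 4)
    h = apply_control(h, control, "middle")  # consumed once, after the middle block
    h = apply_control(h, control, "output")  # consumed once per decoder block

Under the old flat-list convention the middle-block residual was simply the last list entry; the dict makes the middle/output split explicit and lets the consumer pop from each list independently.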

adv_control/nodes.py (+16 -4)

@@ -4,8 +4,8 @@
 import folder_paths
 from comfy.model_patcher import ModelPatcher

-from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet
-from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, BIGMAX
+from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet, is_sd3_advanced_controlnet
+from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, AbstractPreprocWrapper, BIGMAX
 from .nodes_weight import (DefaultWeights, ScaledSoftMaskedUniversalWeights, ScaledSoftUniversalWeights, SoftControlNetWeights, CustomControlNetWeights,
                            SoftT2IAdapterWeights, CustomT2IAdapterWeights)
 from .nodes_keyframes import (LatentKeyframeGroupNode, LatentKeyframeInterpolationNode, LatentKeyframeBatchedGroupNode, LatentKeyframeNode,

@@ -89,6 +89,7 @@ def INPUT_TYPES(s):
                 "latent_kf_override": ("LATENT_KEYFRAME", ),
                 "weights_override": ("CONTROL_NET_WEIGHTS", ),
                 "model_optional": ("MODEL",),
+                "vae_optional": ("VAE",),
             }
         }

@@ -99,7 +100,7 @@ def INPUT_TYPES(s):
     CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"

     def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent,
-                         mask_optional: Tensor=None, model_optional: ModelPatcher=None,
+                         mask_optional: Tensor=None, model_optional: ModelPatcher=None, vae_optional=None,
                          timestep_kf: TimestepKeyframeGroup=None, latent_kf_override: LatentKeyframeGroup=None,
                          weights_override: ControlWeights=None):
         if strength == 0:

@@ -121,7 +122,7 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
                     c_net = cnets[prev_cnet]
                 else:
                     # copy, convert to advanced if needed, and set cond
-                    c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent))
+                    c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent), vae_optional)
                 if is_advanced_controlnet(c_net):
                     # disarm node check
                     c_net.disarm()

@@ -130,6 +131,17 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
                         if not model_optional:
                             raise Exception(f"Type '{type(c_net).__name__}' requires model_optional input, but got None.")
                         c_net.patch_model(model=model_optional)
+                    # if vae required, verify vae is passed in
+                    if c_net.require_vae:
+                        # if controlnet can accept preprocced condhint latents and is the case, ignore vae requirement
+                        if c_net.allow_condhint_latents and isinstance(control_hint, AbstractPreprocWrapper):
+                            pass
+                        elif not vae_optional:
+                            # make sure SD3 ControlNet will get a special message instead of generic type mention
+                            if is_sd3_advanced_controlnet(c_net):
+                                raise Exception(f"SD3 ControlNet requires vae_optional input, but got None.")
+                            else:
+                                raise Exception(f"Type '{type(c_net).__name__}' requires vae_optional input, but got None.")
                 # apply optional parameters and overrides, if provided
                 if timestep_kf is not None:
                     c_net.set_timestep_keyframes(timestep_kf)
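
Condensing the validation added above into a standalone helper makes the decision order easier to see: a preprocessor latent satisfies the VAE requirement on its own, because it is already in latent space and needs no encoding; otherwise a missing VAE raises, with a dedicated message for SD3. The helper function itself is hypothetical, though the attribute and function names are the ones used in the diff:

    from .control import is_sd3_advanced_controlnet
    from .utils import AbstractPreprocWrapper

    def check_vae_requirement(c_net, control_hint, vae_optional):
        if not c_net.require_vae:
            return  # nothing to verify
        # an already-encoded preprocessor latent makes the VAE unnecessary
        if c_net.allow_condhint_latents and isinstance(control_hint, AbstractPreprocWrapper):
            return
        if not vae_optional:
            if is_sd3_advanced_controlnet(c_net):
                raise Exception("SD3 ControlNet requires vae_optional input, but got None.")
            raise Exception(f"Type '{type(c_net).__name__}' requires vae_optional input, but got None.")

In workflow terms: an SD3 ControlNet needs either the new vae_optional input wired into the Apply Advanced ControlNet node, or a preprocessor whose output is one of the latent-pretending-to-be-an-image wrappers above.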

adv_control/nodes_weight.py (+2)

@@ -198,6 +198,7 @@ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
                      uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
         weights = [weight_00, weight_01, weight_02, weight_03]
         weights = get_properly_arranged_t2i_weights(weights)
+        weights.reverse() # to account for recent ComfyUI changes
         weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
         return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))

@@ -229,5 +230,6 @@ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
                      uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
         weights = [weight_00, weight_01, weight_02, weight_03]
         weights = get_properly_arranged_t2i_weights(weights)
+        weights.reverse() # to account for recent ComfyUI changes
         weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
         return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
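
Both T2I-Adapter weight loaders receive the same one-line fix: after the four inputs are arranged into per-block order, the list is reversed. The diff comment says only that this tracks recent ComfyUI changes, which suggests newer ComfyUI consumes T2I-Adapter weights in the opposite block order; that reading is an inference, not stated in the commit. The effect on the list itself is simply:

    weights = [0.25, 0.5, 0.75, 1.0]  # toy values for weight_00..weight_03; the real list
                                      # comes from get_properly_arranged_t2i_weights
    weights.reverse()                 # to account for recent ComfyUI changes
    print(weights)                    # [1.0, 0.75, 0.5, 0.25]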
