
Commit 95441b6

Merge PR #116 from Kosinkadink/develop - SparseCtrl upgrades + bugfixes

2 parents: 6788448 + 576426a

8 files changed: +328 -126 lines

adv_control/control.py (+65 -26)
@@ -3,18 +3,20 @@
 import torch
 import os
 
+import comfy.ops
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
 import comfy.controlnet as comfy_cn
-from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, broadcast_image_to
+from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter
 from comfy.model_patcher import ModelPatcher
 
-from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper
+from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper, SparseConst
 from .control_lllite import LLLiteModule, LLLitePatch
 from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
 from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, ControlWeightType, ControlWeights, WeightTypeException,
-                    manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory)
+                    manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
+                    broadcast_image_to_extend, extend_to_batch_size)
 from .logger import logger
 
 
@@ -55,12 +57,15 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
             del self.cond_hint
         self.cond_hint = None
         # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-        if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
-            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+        if self.sub_idxs is not None:
+            actual_cond_hint_orig = self.cond_hint_original
+            if self.cond_hint_original.size(0) < self.full_latent_length:
+                actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+            self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         else:
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
 
         # prepare mask_cond_hint
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
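
Both helpers swapped in above come from this repo's own .utils module; their implementations are not part of this diff. A minimal sketch of the assumed behavior of extend_to_batch_size (padding a short hint batch by repeating its last frame) and of the sliding-window change it enables; the 16/4 frame counts are invented for illustration:

import torch
from torch import Tensor

# Assumed stand-in for .utils.extend_to_batch_size (not shown in this diff):
# grow a (b, ...) tensor to batch_size along dim 0 by repeating the last entry.
def extend_to_batch_size(tensor: Tensor, batch_size: int) -> Tensor:
    if tensor.size(0) >= batch_size:
        return tensor[:batch_size]
    pad = tensor[-1:].repeat(batch_size - tensor.size(0), *([1] * (tensor.ndim - 1)))
    return torch.cat([tensor, pad], dim=0)

# Invented numbers: 16 latents total, a 4-frame hint, context window 8..11.
# Previously a hint shorter than full_latent_length fell through to the else
# branch and the window positions were ignored; now it is extended first,
# then indexed by the window.
full_latent_length, sub_idxs = 16, [8, 9, 10, 11]
hint = torch.rand(4, 3, 512, 512)
hint = extend_to_batch_size(tensor=hint, batch_size=full_latent_length)
print(hint[sub_idxs].shape)  # torch.Size([4, 3, 512, 512]), aligned to the window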
@@ -97,7 +102,7 @@ def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channel
 
     def control_merge_inject(self, control_input, control_output, control_prev, output_dtype):
         # if has uncond multiplier, need to make sure control shapes are the same batch size as expected
-        if self.weights.has_uncond_multiplier:
+        if self.weights.has_uncond_multiplier or self.weights.has_uncond_mask:
             if control_input is not None:
                 for i in range(len(control_input)):
                     x = control_input[i]
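
For context: the uncond multiplier/mask machinery needs control tensors whose batch covers both cond and uncond rows before it can scale only the uncond half, and the added or self.weights.has_uncond_mask extends that guarantee to masks. A rough sketch of the idea; the repeat-to-batch step is an assumption, not shown in this diff:

import torch

# Invented shapes: control was produced for 8 latents, but the sampler
# batched cond + uncond together, so 16 rows are expected downstream.
x = torch.rand(8, 320, 8, 8)
expected_batch = 16
if x.shape[0] != expected_batch:
    x = x.repeat(2, 1, 1, 1)      # assumed alignment step: tile to cond+uncond
uncond_rows = slice(8, 16)        # second half holds the uncond copies
x[uncond_rows] *= 0.0             # an uncond multiplier/mask can now target them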
@@ -131,9 +136,12 @@ def get_control_advanced(self, x_noisy, t, cond, batched_number):
         if self.sub_idxs is not None:
             # cond hints
             full_cond_hint_original = self.cond_hint_original
+            actual_cond_hint_orig = full_cond_hint_original
             del self.cond_hint
             self.cond_hint = None
-            self.cond_hint_original = full_cond_hint_original[self.sub_idxs]
+            if full_cond_hint_original.size(0) < self.full_latent_length:
+                actual_cond_hint_orig = extend_to_batch_size(tensor=full_cond_hint_original, batch_size=self.full_latent_length)
+            self.cond_hint_original = actual_cond_hint_orig[self.sub_idxs]
             # mask hints
             self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
         return super().get_control(x_noisy, t, cond, batched_number)
@@ -221,12 +229,15 @@ def get_control_advanced(self, x_noisy, t, cond, batched_number):
             del self.cond_hint
         self.cond_hint = None
         # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-        if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
-            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+        if self.sub_idxs is not None:
+            actual_cond_hint_orig = self.cond_hint_original
+            if self.cond_hint_original.size(0) < self.full_latent_length:
+                actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+            self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         else:
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
 
         # prepare mask_cond_hint
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
@@ -291,18 +302,33 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
             del self.cond_hint
         self.cond_hint = None
         # first, figure out which cond idxs are relevant, and where they fit in
-        cond_idxs = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length)
-
+        cond_idxs, hint_order = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length,
+                                                                               sub_idxs=self.sub_idxs if self.sparse_settings.is_context_aware() else None)
         range_idxs = list(range(full_length)) if self.sub_idxs is None else self.sub_idxs
         hint_idxs = [] # idxs in cond_idxs
-        local_idxs = [] # idx to pun in final cond_hint
+        local_idxs = [] # idx to put in final cond_hint
         for i,cond_idx in enumerate(cond_idxs):
             if cond_idx in range_idxs:
                 hint_idxs.append(i)
                 local_idxs.append(range_idxs.index(cond_idx))
+        # log_string = f"cond_idxs: {cond_idxs}, local_idxs: {local_idxs}, hint_idxs: {hint_idxs}, hint_order: {hint_order}"
+        # if self.sub_idxs is not None:
+        #     log_string += f" sub_idxs: {self.sub_idxs[0]}-{self.sub_idxs[-1]}"
+        # logger.warn(log_string)
+        # determine cond/uncond indexes that will get masked
+        self.local_sparse_idxs = []
+        self.local_sparse_idxs_inverse = list(range(x_noisy.size(0)))
+        for batch_idx in range(batched_number):
+            for i in local_idxs:
+                actual_i = i+(batch_idx*actual_length)
+                self.local_sparse_idxs.append(actual_i)
+                if actual_i in self.local_sparse_idxs_inverse:
+                    self.local_sparse_idxs_inverse.remove(actual_i)
         # sub_cond_hint now contains the hints relevant to current x_noisy
-        sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(self.device)
-
+        if hint_order is None:
+            sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(self.device)
+        else:
+            sub_cond_hint = self.cond_hint_original[hint_order][hint_idxs].to(dtype).to(self.device)
         # scale cond_hints to match noisy input
         if self.control_model.use_simplified_conditioning_embedding:
             # RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
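
The bookkeeping added above decides which rows of the batch correspond to a real hint frame inside the current context window, across every cond/uncond copy. A self-contained rerun of that same logic with invented numbers (16-frame video, hints at frames 0/5/10/15, window 4..11, batched_number=2):

# Invented example values; real cond_idxs come from the SparseMethod.
cond_idxs = [0, 5, 10, 15]            # where the 4 hints land in the full video
sub_idxs = list(range(4, 12))         # current context window
actual_length = len(sub_idxs)         # latents per sample in this window
batched_number = 2                    # cond + uncond

range_idxs = sub_idxs
hint_idxs, local_idxs = [], []
for i, cond_idx in enumerate(cond_idxs):
    if cond_idx in range_idxs:
        hint_idxs.append(i)                            # which hints are visible
        local_idxs.append(range_idxs.index(cond_idx))  # their offsets in the window

local_sparse_idxs = []
local_sparse_idxs_inverse = list(range(actual_length * batched_number))
for batch_idx in range(batched_number):
    for i in local_idxs:
        actual_i = i + (batch_idx * actual_length)
        local_sparse_idxs.append(actual_i)
        if actual_i in local_sparse_idxs_inverse:
            local_sparse_idxs_inverse.remove(actual_i)

print(hint_idxs)           # [1, 2] -> hints at frames 5 and 10 fall in the window
print(local_idxs)          # [1, 6] -> their offsets inside the window
print(local_sparse_idxs)   # [1, 6, 9, 14] -> those offsets in both cond and uncond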
@@ -319,15 +345,15 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
             # prepare cond_mask (b, 1, h, w)
             cond_shape[1] = 1
             cond_mask = torch.zeros(cond_shape).to(dtype).to(self.device)
-            cond_mask[local_idxs] = 1.0
+            cond_mask[local_idxs] = self.sparse_settings.sparse_mask_mult * self.weights.extras.get(SparseConst.MASK_MULT, 1.0)
             # combine cond_hint and cond_mask into (b, c+1, h, w)
             if not self.sparse_settings.merged:
                 self.cond_hint = torch.cat([self.cond_hint, cond_mask], dim=1)
             del sub_cond_hint
             del cond_mask
         # make cond_hint match x_noisy batch
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
 
         # prepare mask_cond_hint
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
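
The cond_mask built above is the per-frame indicator channel that SparseCtrl's conditioning embedding uses to tell hinted frames apart from padding; the change makes its magnitude adjustable through sparse_mask_mult and a SparseConst.MASK_MULT entry in the weight extras (both presumably defaulting to 1.0). A shape-level sketch with invented values:

import torch

b, c, h, w = 8, 4, 64, 64
cond_hint = torch.rand(b, c, h, w)   # scaled hints (zeros where no hint exists)
local_idxs = [1, 6]                  # window offsets that actually carry a hint
mask_value = 1.0                     # sparse_mask_mult * MASK_MULT extra

cond_mask = torch.zeros(b, 1, h, w)
cond_mask[local_idxs] = mask_value   # mark hinted frames for the cond embedding
cond_hint = torch.cat([cond_hint, cond_mask], dim=1)  # (b, c+1, h, w)
print(cond_hint.shape)               # torch.Size([8, 5, 64, 64])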
@@ -342,6 +368,12 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
         return self.control_merge(None, control, control_prev, output_dtype)
 
+    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
+        # apply mults to indexes with and without a direct condhint
+        x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0)
+        x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0)
+        return super().apply_advanced_strengths_and_masks(x, batched_number)
+
     def pre_run_advanced(self, model, percent_to_timestep_function):
         super().pre_run_advanced(model, percent_to_timestep_function)
         if type(self.cond_hint_original) == PreprocSparseRGBWrapper:
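
The new apply_advanced_strengths_and_masks override scales each control tensor row-wise before deferring to the base-class behavior, so frames with a direct hint and frames without one can be weighted independently. A sketch of the arithmetic, with invented multiplier settings and the index lists from the earlier example:

import torch

x = torch.ones(16, 320, 8, 8)              # one control tensor: (batch, c, h, w)
local_sparse_idxs = [1, 6, 9, 14]           # rows with a direct cond hint
local_sparse_idxs_inverse = [i for i in range(16) if i not in local_sparse_idxs]

sparse_hint_mult, sparse_nonhint_mult = 1.0, 0.5   # invented settings
x[local_sparse_idxs] *= sparse_hint_mult            # keep hinted frames at full strength
x[local_sparse_idxs_inverse] *= sparse_nonhint_mult # damp control on un-hinted frames
print(x[1, 0, 0, 0].item(), x[0, 0, 0, 0].item())   # 1.0 0.5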
@@ -359,6 +391,8 @@ def cleanup_advanced(self):
         if self.latent_format is not None:
             del self.latent_format
             self.latent_format = None
+        self.local_sparse_idxs = None
+        self.local_sparse_idxs_inverse = None
 
     def copy(self):
         c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.device, self.load_device, self.manual_cast_dtype)
@@ -411,12 +445,15 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
             del self.cond_hint
         self.cond_hint = None
         # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-        if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
-            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+        if self.sub_idxs is not None:
+            actual_cond_hint_orig = self.cond_hint_original
+            if self.cond_hint_original.size(0) < self.full_latent_length:
+                actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+            self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         else:
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
         # some special logic here compared to other controlnets:
         # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
         # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
@@ -551,13 +588,9 @@ def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, tim
             motion_data[key] = controlnet_data.pop(key)
     if len(motion_data) == 0:
         raise ValueError(f"No motion-related keys in '{ckpt_path}'; not a valid SparseCtrl model!")
-    motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data).to(comfy.model_management.unet_dtype())
-    missing, unexpected = motion_wrapper.load_state_dict(motion_data)
-    if len(missing) > 0 or len(unexpected) > 0:
-        logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")
 
     # now, load as if it was a normal controlnet - mostly copied from comfy load_controlnet function
-    controlnet_config = None
+    controlnet_config: dict[str] = None
     is_diffusers = False
     use_simplified_conditioning_embedding = False
     if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:
@@ -681,6 +714,12 @@ class WeightsLoader(torch.nn.Module):
     if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True
 
+    # actually load motion portion of model now
+    motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data, ops=controlnet_config.get("operations", None)).to(comfy.model_management.unet_dtype())
+    missing, unexpected = motion_wrapper.load_state_dict(motion_data)
+    if len(missing) > 0 or len(unexpected) > 0:
+        logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")
+
     # both motion portion and controlnet portions are loaded; bring them together if using motion model
     if sparse_settings.use_motion:
         motion_wrapper.inject(control_model)
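
Moving the motion-wrapper construction here, after model detection has produced controlnet_config, is what allows the new ops=controlnet_config.get("operations", None) argument; in the removed lines near the top of load_sparsectrl the config did not exist yet. A schematic of that dependency with a placeholder class (the real SparseCtrlMotionWrapper lives in control_sparsectrl.py and is only partially visible in this diff):

import torch

# Placeholder module; constructor signature inferred from the call site above.
class SparseCtrlMotionWrapper(torch.nn.Module):
    def __init__(self, state_dict: dict, ops=None):
        super().__init__()
        self.ops = ops  # casting/weight-init ops chosen during model detection

motion_data = {}                             # motion keys popped from the checkpoint
controlnet_config = {"operations": None}     # produced by comfy.model_detection above
wrapper = SparseCtrlMotionWrapper(motion_data, ops=controlnet_config.get("operations", None))
missing, unexpected = wrapper.load_state_dict(motion_data)
if len(missing) > 0 or len(unexpected) > 0:
    print(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")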

adv_control/control_reference.py (+2 -2)
@@ -14,7 +14,7 @@
 
 from .logger import logger
 from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, AbstractPreprocWrapper,
-                    deepcopy_with_sharing, prepare_mask_batch, broadcast_image_to_full)
+                    deepcopy_with_sharing, prepare_mask_batch, broadcast_image_to_extend)
 
 
 def refcn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
@@ -326,7 +326,7 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
                                                          self.cond_hint_original,
                                                          x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to_full(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
         # noise cond_hint based on sigma (current step)
         self.cond_hint = self.latent_format.process_in(self.cond_hint)
         self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
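
Here broadcast_image_to_extend replaces broadcast_image_to_full with an identical call shape, including except_one=False so even a single reference latent is tiled explicitly. Its implementation is not in this diff; a guessed sketch consistent with the call sites, assuming it extends per sampled batch and then tiles across the cond/uncond copies:

import torch
from torch import Tensor

# Hypothetical sketch of broadcast_image_to_extend; signature inferred
# from the call sites in this commit, not taken from the repo.
def broadcast_image_to_extend(tensor: Tensor, target_batch_size: int,
                              batched_number: int, except_one: bool = True) -> Tensor:
    if except_one and tensor.shape[0] == 1:
        return tensor                        # leave single images to implicit broadcasting
    per_batch = target_batch_size // batched_number
    if tensor.shape[0] < per_batch:          # extend by repeating the last frame
        pad = tensor[-1:].repeat(per_batch - tensor.shape[0], *([1] * (tensor.ndim - 1)))
        tensor = torch.cat([tensor, pad], dim=0)
    else:
        tensor = tensor[:per_batch]
    return tensor.repeat(batched_number, *([1] * (tensor.ndim - 1)))

hint = torch.rand(4, 3, 64, 64)
out = broadcast_image_to_extend(hint, target_batch_size=16, batched_number=2, except_one=False)
print(out.shape)                             # torch.Size([16, 3, 64, 64])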
