Commit a91a3ac

Merge PR #141 - ControlLLLite refactor + vanilla CN conversion w/ context_opts
ControlLLLite refactor + vanilla CN conversion when using sliding context
2 parents d3c6ae0 + 07b8e3e commit a91a3ac

7 files changed: +547 −312 lines

adv_control/control.py

+119 −162
@@ -12,14 +12,17 @@
 from comfy.model_patcher import ModelPatcher
 
 from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst
-from .control_lllite import LLLiteModule, LLLitePatch
+from .control_lllite import LLLiteModule, LLLitePatch, load_controllllite
 from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
 from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, AbstractPreprocWrapper, ControlWeightType, ControlWeights, WeightTypeException,
                     manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
                     broadcast_image_to_extend, extend_to_batch_size)
 from .logger import logger
 
 
+ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
+
+
 class ControlNetAdvanced(ControlNet, AdvancedControlBase):
     def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
@@ -98,8 +101,10 @@ def copy(self):
 
     @staticmethod
     def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
-        return ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
+        to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
                                   global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
+        v.copy_to(to_return)
+        return to_return
 
 
 class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
@@ -166,8 +171,10 @@ def cleanup(self):
 
     @staticmethod
     def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
-        return T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
+        to_return = T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
                                   compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device)
+        v.copy_to(to_return)
+        return to_return
 
 
 class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
@@ -194,8 +201,10 @@ def cleanup(self):
 
     @staticmethod
     def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
-        return ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
+        to_return = ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
                                    global_average_pooling=v.global_average_pooling, device=v.device)
+        v.copy_to(to_return)
+        return to_return
 
 
 class SVDControlNetAdvanced(ControlNetAdvanced):
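
Note: all three from_vanilla helpers above follow the same pattern — build the advanced object, then call v.copy_to(to_return) so state already set on the vanilla object is not dropped during conversion. Below is a minimal sketch of why the extra call matters, assuming ComfyUI's ControlBase.copy_to carries over fields such as cond_hint_original, strength, and timestep_percent_range (the vanilla_cn argument and the hint tensor are hypothetical):

    import torch

    def demo_from_vanilla(vanilla_cn):
        # user already attached a hint and strength to the vanilla controlnet
        vanilla_cn.set_cond_hint(torch.rand(1, 3, 512, 512), strength=0.65)
        adv = ControlNetAdvanced.from_vanilla(vanilla_cn)
        # without the new v.copy_to(to_return) call, the converted object
        # would start with a blank hint/strength and silently ignore both
        assert adv.cond_hint_original is vanilla_cn.cond_hint_original
        assert adv.strength == vanilla_cn.strength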
@@ -408,115 +417,6 @@ def copy(self):
         return c
 
 
-class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
-    # This ControlNet is more of an attention patch than a traditional controlnet
-    def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device=None):
-        super().__init__(device)
-        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), require_model=True)
-        self.patch_attn1 = patch_attn1.set_control(self)
-        self.patch_attn2 = patch_attn2.set_control(self)
-        self.latent_dims_div2 = None
-        self.latent_dims_div4 = None
-
-    def patch_model(self, model: ModelPatcher):
-        model.set_model_attn1_patch(self.patch_attn1)
-        model.set_model_attn2_patch(self.patch_attn2)
-
-    def set_cond_hint_inject(self, *args, **kwargs):
-        to_return = super().set_cond_hint_inject(*args, **kwargs)
-        # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
-        self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
-        return to_return
-
-    def pre_run_advanced(self, *args, **kwargs):
-        AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
-        #logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
-        self.patch_attn1.set_control(self)
-        self.patch_attn2.set_control(self)
-        #logger.warn(f"in pre_run_advanced: {id(self)}")
-
-    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
-        # normal ControlNet stuff
-        control_prev = None
-        if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
-
-        if self.timestep_range is not None:
-            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
-                return control_prev
-
-        dtype = x_noisy.dtype
-        # prepare cond_hint
-        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
-            if self.cond_hint is not None:
-                del self.cond_hint
-                self.cond_hint = None
-            # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-            if self.sub_idxs is not None:
-                actual_cond_hint_orig = self.cond_hint_original
-                if self.cond_hint_original.size(0) < self.full_latent_length:
-                    actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
-                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
-            else:
-                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
-        if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
-        # some special logic here compared to other controlnets:
-        # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer division
-        # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisible by 2 or 4
-        divisible_by_2_h = x_noisy.shape[2]%2==0
-        divisible_by_2_w = x_noisy.shape[3]%2==0
-        if not (divisible_by_2_h and divisible_by_2_w):
-            #logger.warn(f"{x_noisy.shape} not divisible by 2!")
-            new_h = (x_noisy.shape[2]//2)*2
-            new_w = (x_noisy.shape[3]//2)*2
-            if not divisible_by_2_h:
-                new_h += 2
-            if not divisible_by_2_w:
-                new_w += 2
-            self.latent_dims_div2 = (new_h, new_w)
-        divisible_by_4_h = x_noisy.shape[2]%4==0
-        divisible_by_4_w = x_noisy.shape[3]%4==0
-        if not (divisible_by_4_h and divisible_by_4_w):
-            #logger.warn(f"{x_noisy.shape} not divisible by 4!")
-            new_h = (x_noisy.shape[2]//4)*4
-            new_w = (x_noisy.shape[3]//4)*4
-            if not divisible_by_4_h:
-                new_h += 4
-            if not divisible_by_4_w:
-                new_w += 4
-            self.latent_dims_div4 = (new_h, new_w)
-        # prepare mask
-        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
-        # done preparing; model patches will take care of everything now.
-        # return normal controlnet stuff
-        return control_prev
-
-    def cleanup_advanced(self):
-        super().cleanup_advanced()
-        self.patch_attn1.cleanup()
-        self.patch_attn2.cleanup()
-        self.latent_dims_div2 = None
-        self.latent_dims_div4 = None
-
-    def copy(self):
-        c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes)
-        self.copy_to(c)
-        self.copy_to_advanced(c)
-        return c
-
-    # deepcopy needs to properly keep track of objects to work between model.clone calls!
-    # def __deepcopy__(self, *args, **kwargs):
-    #     self.cleanup_advanced()
-    #     return self
-
-    # def get_models(self):
-    #     # get_models is called once at the start of every KSampler run - use to reset already_patched status
-    #     out = super().get_models()
-    #     logger.error(f"in get_models! {id(self)}")
-    #     return out
-
-
 def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     # from pathlib import Path
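
The ControlLLLiteAdvanced class deleted above is not gone — per the import change in the first hunk, it presumably now lives alongside load_controllllite in control_lllite.py (the refactor named in the commit title). Its oddest bookkeeping is the latent_dims_div2/latent_dims_div4 pair: LLLite's cond_emb shrinks latent dims by integer division by 2 or 4, so dims that don't divide evenly get rounded up to the next multiple for the attention patches to target. A standalone sketch of that rounding rule (the helper name is mine, not from the codebase):

    def round_up_latent_dims(h: int, w: int, div: int):
        """Mirror of the div2/div4 logic in get_control_advanced: returns None
        when (h, w) already divide evenly, else (h, w) rounded up to the next
        multiple of div."""
        if h % div == 0 and w % div == 0:
            return None
        new_h = (h // div) * div
        new_w = (w // div) * div
        if h % div != 0:
            new_h += div
        if w % div != 0:
            new_w += div
        return (new_h, new_w)

    assert round_up_latent_dims(67, 90, 2) == (68, 90)   # 67 is odd -> bump to 68
    assert round_up_latent_dims(64, 96, 4) is None       # already divisible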
@@ -591,6 +491,112 @@ def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
     return control
 
 
+def convert_all_to_advanced(conds: list[list[dict[str]]]) -> tuple[bool, list]:
+    cache = {}
+    modified = False
+    new_conds = []
+    for cond in conds:
+        converted_cond = None
+        if cond is not None:
+            need_to_convert = False
+            # first, check if there is even a need to convert
+            for sub_cond in cond:
+                actual_cond = sub_cond[1]
+                if "control" in actual_cond:
+                    if not are_all_advanced_controlnet(actual_cond["control"]):
+                        need_to_convert = True
+                        break
+            if not need_to_convert:
+                converted_cond = cond
+            else:
+                converted_cond = []
+                for sub_cond in cond:
+                    new_sub_cond: list = []
+                    for actual_cond in sub_cond:
+                        if not type(actual_cond) == dict:
+                            new_sub_cond.append(actual_cond)
+                            continue
+                        if "control" not in actual_cond:
+                            new_sub_cond.append(actual_cond)
+                        elif are_all_advanced_controlnet(actual_cond["control"]):
+                            new_sub_cond.append(actual_cond)
+                        else:
+                            actual_cond = actual_cond.copy()
+                            actual_cond["control"] = _convert_all_control_to_advanced(actual_cond["control"], cache)
+                            new_sub_cond.append(actual_cond)
+                            modified = True
+                    converted_cond.append(new_sub_cond)
+        new_conds.append(converted_cond)
+    return modified, new_conds
+
+
+def _convert_all_control_to_advanced(input_object: ControlBase, cache: dict):
+    output_object = input_object
+    # iteratively convert to advanced, if needed
+    next_cn = None
+    curr_cn = input_object
+    iter = 0
+    while curr_cn is not None:
+        if not is_advanced_controlnet(curr_cn):
+            # if already in cache, then conversion was done before, so just link it and exit
+            if curr_cn in cache:
+                new_cn = cache[curr_cn]
+                if next_cn is not None:
+                    setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
+                    next_cn.previous_controlnet = new_cn
+                if iter == 0: # if was top-level controlnet, that's the new output
+                    output_object = new_cn
+                break
+            try:
+                # convert to advanced, and assign previous_controlnet (convert doesn't transfer it)
+                new_cn = convert_to_advanced(curr_cn)
+            except Exception as e:
+                raise Exception("Failed to automatically convert a ControlNet to Advanced to support sliding window context.", e)
+            new_cn.previous_controlnet = curr_cn.previous_controlnet
+            if iter == 0: # if was top-level controlnet, that's the new output
+                output_object = new_cn
+            # if next_cn is present, then it needs to be pointed to new_cn
+            if next_cn is not None:
+                setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
+                next_cn.previous_controlnet = new_cn
+            # add to cache
+            cache[curr_cn] = new_cn
+            curr_cn = new_cn
+        next_cn = curr_cn
+        curr_cn = curr_cn.previous_controlnet
+        iter += 1
+    return output_object
+
+
+def restore_all_controlnet_conns(conds: list[list[dict[str]]]):
+    # if a cn has an _orig_previous_controlnet property, restore it and delete
+    for main_cond in conds:
+        if main_cond is not None:
+            for cond in main_cond:
+                if "control" in cond[1]:
+                    _restore_all_controlnet_conns(cond[1]["control"])
+
+
+def _restore_all_controlnet_conns(input_object: ControlBase):
+    # restore original previous_controlnet if needed
+    curr_cn = input_object
+    while curr_cn is not None:
+        if hasattr(curr_cn, ORIG_PREVIOUS_CONTROLNET):
+            curr_cn.previous_controlnet = getattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
+            delattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
+        curr_cn = curr_cn.previous_controlnet
+
+
+def are_all_advanced_controlnet(input_object: ControlBase):
+    # iteratively check if linked controlnet objects are all advanced
+    curr_cn = input_object
+    while curr_cn is not None:
+        if not is_advanced_controlnet(curr_cn):
+            return False
+        curr_cn = curr_cn.previous_controlnet
+    return True
+
+
 def is_advanced_controlnet(input_object):
     return hasattr(input_object, "sub_idxs")

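Taken together, the helpers added above form a stash-and-restore pass: convert_all_to_advanced walks every control chain hanging off the conds, swaps each vanilla controlnet for an advanced twin (stashing the original previous_controlnet link under the ORIG_PREVIOUS_CONTROLNET attribute), and restore_all_controlnet_conns later undoes the relinking. A hedged usage sketch — sample_fn stands in for wherever the sliding-context sampling actually happens and is not a function in this file:

    def sample_with_advanced_conds(model, positive, negative, sample_fn):
        # positive/negative are ComfyUI cond lists: [[cond_tensor, {"control": cn, ...}], ...]
        modified, (positive, negative) = convert_all_to_advanced([positive, negative])
        try:
            return sample_fn(model, positive, negative)
        finally:
            # the converted chains hold the stashed links, so restore them
            # even if sampling raised partway through
            if modified:
                restore_all_controlnet_conns([positive, negative])
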
@@ -749,55 +755,6 @@ class WeightsLoader(torch.nn.Module):
     return control
 
 
-def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
-    if controlnet_data is None:
-        controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
-    # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
-    # first, split weights for each module
-    module_weights = {}
-    for key, value in controlnet_data.items():
-        fragments = key.split(".")
-        module_name = fragments[0]
-        weight_name = ".".join(fragments[1:])
-
-        if module_name not in module_weights:
-            module_weights[module_name] = {}
-        module_weights[module_name][weight_name] = value
-
-    # next, load each module
-    modules = {}
-    for module_name, weights in module_weights.items():
-        # kohya planned to do something about how these should be chosen, so I'm not touching this
-        # since I am not familiar with the logic for this
-        if "conditioning1.4.weight" in weights:
-            depth = 3
-        elif weights["conditioning1.2.weight"].shape[-1] == 4:
-            depth = 2
-        else:
-            depth = 1
-
-        module = LLLiteModule(
-            name=module_name,
-            is_conv2d=weights["down.0.weight"].ndim == 4,
-            in_dim=weights["down.0.weight"].shape[1],
-            depth=depth,
-            cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
-            mlp_dim=weights["down.0.weight"].shape[0],
-        )
-        # load weights into module
-        module.load_state_dict(weights)
-        modules[module_name] = module
-        if len(modules) == 1:
-            module.is_first = True
-
-    #logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
-
-    patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
-    patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
-    control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe)
-    return control
-
-
 def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
     if controlnet_data is None:
         controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

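load_controllllite itself also moved rather than being dropped — the first hunk now imports it from .control_lllite. Before building each LLLiteModule, it buckets the flat checkpoint state dict by top-level module name. The grouping step in isolation, on a toy state dict (key names and tensor shapes here are illustrative only, not taken from a real checkpoint):

    import torch

    controlnet_data = {
        "lllite_unet_attn1_to_q.down.0.weight": torch.zeros(64, 320),
        "lllite_unet_attn1_to_q.up.0.weight": torch.zeros(320, 64),
        "lllite_unet_attn2_to_q.down.0.weight": torch.zeros(64, 640),
    }

    module_weights = {}
    for key, value in controlnet_data.items():
        # everything before the first "." names the module; the rest names the weight
        module_name, _, weight_name = key.partition(".")
        module_weights.setdefault(module_name, {})[weight_name] = value

    assert sorted(module_weights) == ["lllite_unet_attn1_to_q", "lllite_unet_attn2_to_q"]
    assert "down.0.weight" in module_weights["lllite_unet_attn1_to_q"]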