12
12
from comfy .model_patcher import ModelPatcher
13
13
14
14
from .control_sparsectrl import SparseModelPatcher , SparseControlNet , SparseCtrlMotionWrapper , SparseSettings , SparseConst
15
- from .control_lllite import LLLiteModule , LLLitePatch
15
+ from .control_lllite import LLLiteModule , LLLitePatch , load_controllllite
16
16
from .control_svd import svd_unet_config_from_diffusers_unet , SVDControlNet , svd_unet_to_diffusers
17
17
from .utils import (AdvancedControlBase , TimestepKeyframeGroup , LatentKeyframeGroup , AbstractPreprocWrapper , ControlWeightType , ControlWeights , WeightTypeException ,
18
18
manual_cast_clean_groupnorm , disable_weight_init_clean_groupnorm , prepare_mask_batch , get_properly_arranged_t2i_weights , load_torch_file_with_dict_factory ,
19
19
broadcast_image_to_extend , extend_to_batch_size )
20
20
from .logger import logger
21
21
22
22
23
+ ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
24
+
25
+
23
26
class ControlNetAdvanced (ControlNet , AdvancedControlBase ):
24
27
def __init__ (self , control_model , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None ):
25
28
super ().__init__ (control_model = control_model , global_average_pooling = global_average_pooling , compression_ratio = compression_ratio , latent_format = latent_format , device = device , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
@@ -98,8 +101,10 @@ def copy(self):
98
101
99
102
@staticmethod
def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
    """Create a ControlNetAdvanced that reuses the control model and settings of ``v``.

    The constructor arguments are read directly off ``v``. Afterwards
    ``v.copy_to`` is called on the new object, so whatever state that method
    transfers is carried over too (what it copies is defined outside this block).
    """
    init_kwargs = dict(
        control_model=v.control_model,
        timestep_keyframes=timestep_keyframe,
        global_average_pooling=v.global_average_pooling,
        compression_ratio=v.compression_ratio,
        latent_format=v.latent_format,
        device=v.device,
        load_device=v.load_device,
        manual_cast_dtype=v.manual_cast_dtype,
    )
    advanced = ControlNetAdvanced(**init_kwargs)
    v.copy_to(advanced)
    return advanced
103
108
104
109
105
110
class T2IAdapterAdvanced (T2IAdapter , AdvancedControlBase ):
@@ -166,8 +171,10 @@ def cleanup(self):
166
171
167
172
@staticmethod
def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
    """Create a T2IAdapterAdvanced that reuses the t2i model and settings of ``v``.

    The constructor arguments are read directly off ``v``. Afterwards
    ``v.copy_to`` is called on the new object, so whatever state that method
    transfers is carried over too (what it copies is defined outside this block).
    """
    init_kwargs = dict(
        t2i_model=v.t2i_model,
        timestep_keyframes=timestep_keyframe,
        channels_in=v.channels_in,
        compression_ratio=v.compression_ratio,
        upscale_algorithm=v.upscale_algorithm,
        device=v.device,
    )
    advanced = T2IAdapterAdvanced(**init_kwargs)
    v.copy_to(advanced)
    return advanced
171
178
172
179
173
180
class ControlLoraAdvanced (ControlLora , AdvancedControlBase ):
@@ -194,8 +201,10 @@ def cleanup(self):
194
201
195
202
@staticmethod
def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
    """Create a ControlLoraAdvanced that reuses the control weights and settings of ``v``.

    The constructor arguments are read directly off ``v``. Afterwards
    ``v.copy_to`` is called on the new object, so whatever state that method
    transfers is carried over too (what it copies is defined outside this block).
    """
    init_kwargs = dict(
        control_weights=v.control_weights,
        timestep_keyframes=timestep_keyframe,
        global_average_pooling=v.global_average_pooling,
        device=v.device,
    )
    advanced = ControlLoraAdvanced(**init_kwargs)
    v.copy_to(advanced)
    return advanced
199
208
200
209
201
210
class SVDControlNetAdvanced (ControlNetAdvanced ):
@@ -408,115 +417,6 @@ def copy(self):
408
417
return c
409
418
410
419
411
- class ControlLLLiteAdvanced (ControlBase , AdvancedControlBase ):
412
- # This ControlNet is more of an attention patch than a traditional controlnet
413
- def __init__ (self , patch_attn1 : LLLitePatch , patch_attn2 : LLLitePatch , timestep_keyframes : TimestepKeyframeGroup , device = None ):
414
- super ().__init__ (device )
415
- AdvancedControlBase .__init__ (self , super (), timestep_keyframes = timestep_keyframes , weights_default = ControlWeights .controllllite (), require_model = True )
416
- self .patch_attn1 = patch_attn1 .set_control (self )
417
- self .patch_attn2 = patch_attn2 .set_control (self )
418
- self .latent_dims_div2 = None
419
- self .latent_dims_div4 = None
420
-
421
- def patch_model (self , model : ModelPatcher ):
422
- model .set_model_attn1_patch (self .patch_attn1 )
423
- model .set_model_attn2_patch (self .patch_attn2 )
424
-
425
- def set_cond_hint_inject (self , * args , ** kwargs ):
426
- to_return = super ().set_cond_hint_inject (* args , ** kwargs )
427
- # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
428
- self .cond_hint_original = self .cond_hint_original * 2.0 - 1.0
429
- return to_return
430
-
431
- def pre_run_advanced (self , * args , ** kwargs ):
432
- AdvancedControlBase .pre_run_advanced (self , * args , ** kwargs )
433
- #logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
434
- self .patch_attn1 .set_control (self )
435
- self .patch_attn2 .set_control (self )
436
- #logger.warn(f"in pre_run_advanced: {id(self)}")
437
-
438
- def get_control_advanced (self , x_noisy : Tensor , t , cond , batched_number : int ):
439
- # normal ControlNet stuff
440
- control_prev = None
441
- if self .previous_controlnet is not None :
442
- control_prev = self .previous_controlnet .get_control (x_noisy , t , cond , batched_number )
443
-
444
- if self .timestep_range is not None :
445
- if t [0 ] > self .timestep_range [0 ] or t [0 ] < self .timestep_range [1 ]:
446
- return control_prev
447
-
448
- dtype = x_noisy .dtype
449
- # prepare cond_hint
450
- if self .sub_idxs is not None or self .cond_hint is None or x_noisy .shape [2 ] * 8 != self .cond_hint .shape [2 ] or x_noisy .shape [3 ] * 8 != self .cond_hint .shape [3 ]:
451
- if self .cond_hint is not None :
452
- del self .cond_hint
453
- self .cond_hint = None
454
- # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
455
- if self .sub_idxs is not None :
456
- actual_cond_hint_orig = self .cond_hint_original
457
- if self .cond_hint_original .size (0 ) < self .full_latent_length :
458
- actual_cond_hint_orig = extend_to_batch_size (tensor = actual_cond_hint_orig , batch_size = self .full_latent_length )
459
- self .cond_hint = comfy .utils .common_upscale (actual_cond_hint_orig [self .sub_idxs ], x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (self .device )
460
- else :
461
- self .cond_hint = comfy .utils .common_upscale (self .cond_hint_original , x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (self .device )
462
- if x_noisy .shape [0 ] != self .cond_hint .shape [0 ]:
463
- self .cond_hint = broadcast_image_to_extend (self .cond_hint , x_noisy .shape [0 ], batched_number )
464
- # some special logic here compared to other controlnets:
465
- # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
466
- # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
467
- divisible_by_2_h = x_noisy .shape [2 ]% 2 == 0
468
- divisible_by_2_w = x_noisy .shape [3 ]% 2 == 0
469
- if not (divisible_by_2_h and divisible_by_2_w ):
470
- #logger.warn(f"{x_noisy.shape} not divisible by 2!")
471
- new_h = (x_noisy .shape [2 ]// 2 )* 2
472
- new_w = (x_noisy .shape [3 ]// 2 )* 2
473
- if not divisible_by_2_h :
474
- new_h += 2
475
- if not divisible_by_2_w :
476
- new_w += 2
477
- self .latent_dims_div2 = (new_h , new_w )
478
- divisible_by_4_h = x_noisy .shape [2 ]% 4 == 0
479
- divisible_by_4_w = x_noisy .shape [3 ]% 4 == 0
480
- if not (divisible_by_4_h and divisible_by_4_w ):
481
- #logger.warn(f"{x_noisy.shape} not divisible by 4!")
482
- new_h = (x_noisy .shape [2 ]// 4 )* 4
483
- new_w = (x_noisy .shape [3 ]// 4 )* 4
484
- if not divisible_by_4_h :
485
- new_h += 4
486
- if not divisible_by_4_w :
487
- new_w += 4
488
- self .latent_dims_div4 = (new_h , new_w )
489
- # prepare mask
490
- self .prepare_mask_cond_hint (x_noisy = x_noisy , t = t , cond = cond , batched_number = batched_number )
491
- # done preparing; model patches will take care of everything now.
492
- # return normal controlnet stuff
493
- return control_prev
494
-
495
- def cleanup_advanced (self ):
496
- super ().cleanup_advanced ()
497
- self .patch_attn1 .cleanup ()
498
- self .patch_attn2 .cleanup ()
499
- self .latent_dims_div2 = None
500
- self .latent_dims_div4 = None
501
-
502
- def copy (self ):
503
- c = ControlLLLiteAdvanced (self .patch_attn1 , self .patch_attn2 , self .timestep_keyframes )
504
- self .copy_to (c )
505
- self .copy_to_advanced (c )
506
- return c
507
-
508
- # deepcopy needs to properly keep track of objects to work between model.clone calls!
509
- # def __deepcopy__(self, *args, **kwargs):
510
- # self.cleanup_advanced()
511
- # return self
512
-
513
- # def get_models(self):
514
- # # get_models is called once at the start of every KSampler run - use to reset already_patched status
515
- # out = super().get_models()
516
- # logger.error(f"in get_models! {id(self)}")
517
- # return out
518
-
519
-
520
420
def load_controlnet (ckpt_path , timestep_keyframe : TimestepKeyframeGroup = None , model = None ):
521
421
controlnet_data = comfy .utils .load_torch_file (ckpt_path , safe_load = True )
522
422
# from pathlib import Path
@@ -591,6 +491,112 @@ def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
591
491
return control
592
492
593
493
494
def convert_all_to_advanced(conds: list[list[dict[str]]]) -> tuple[bool, list]:
    """Return copies of ``conds`` whose "control" entries are all advanced controlnets.

    Each item of ``conds`` is either ``None`` or a list of sub-conds; each
    sub-cond is a sequence whose dict elements may carry a ``"control"`` key.
    Items that need no conversion are passed through as the very same object.
    Items that do need conversion are rebuilt: the sub-cond lists are new, the
    affected dicts are shallow-copied, and their ``"control"`` value is replaced
    by the result of ``_convert_all_control_to_advanced``.

    A single cache is shared across the whole call so that a controlnet
    referenced from several conds is converted only once.

    Returns:
        ``(modified, new_conds)`` where ``modified`` is True if at least one
        control entry was replaced.
    """
    def _needs_conversion(entry) -> bool:
        # Only dict elements can hold a "control" chain; anything else is passed through.
        return (
            isinstance(entry, dict)
            and "control" in entry
            and not are_all_advanced_controlnet(entry["control"])
        )

    cache = {}
    modified = False
    new_conds = []
    for cond in conds:
        # None placeholders and conds with nothing to convert keep their identity.
        if cond is None or not any(
            _needs_conversion(entry) for sub_cond in cond for entry in sub_cond
        ):
            new_conds.append(cond)
            continue
        converted_cond = []
        for sub_cond in cond:
            new_sub_cond: list = []
            for entry in sub_cond:
                if _needs_conversion(entry):
                    # copy the dict so the caller's original cond is not mutated
                    entry = entry.copy()
                    entry["control"] = _convert_all_control_to_advanced(entry["control"], cache)
                    modified = True
                new_sub_cond.append(entry)
            converted_cond.append(new_sub_cond)
        new_conds.append(converted_cond)
    return modified, new_conds
531
+
532
+
533
def _convert_all_control_to_advanced(input_object: ControlBase, cache: dict):
    """Walk a ``previous_controlnet`` chain, replacing non-advanced links with advanced ones.

    Args:
        input_object: head of the chain (may be None).
        cache: maps already-converted original controlnets to their advanced
            replacement; it is read and updated here so repeated calls reuse
            earlier conversions.

    Returns:
        The head of the resulting chain: the advanced replacement if the head
        itself was converted (or found in ``cache``), otherwise ``input_object``.

    Raises:
        Exception: if ``convert_to_advanced`` fails for any link; the original
            error is attached both as an argument and as ``__cause__``.

    Side effects:
        When an existing parent link is re-pointed at a replacement, its prior
        ``previous_controlnet`` is stashed under ``ORIG_PREVIOUS_CONTROLNET`` so
        it can be restored later.
    """
    def _relink(parent, replacement):
        # Stash the parent's current link before re-pointing it at the replacement.
        if parent is not None:
            setattr(parent, ORIG_PREVIOUS_CONTROLNET, parent.previous_controlnet)
            parent.previous_controlnet = replacement

    output_object = input_object
    parent_cn = None
    curr_cn = input_object
    is_top_level = True
    while curr_cn is not None:
        if not is_advanced_controlnet(curr_cn):
            # if already in cache, then conversion was done before, so just link it and exit
            if curr_cn in cache:
                new_cn = cache[curr_cn]
                _relink(parent_cn, new_cn)
                if is_top_level:  # top-level controlnet becomes the new output
                    output_object = new_cn
                break
            try:
                new_cn = convert_to_advanced(curr_cn)
            except Exception as e:
                raise Exception("Failed to automatically convert a ControlNet to Advanced to support sliding window context.", e) from e
            # convert doesn't transfer previous_controlnet, so carry it over explicitly
            new_cn.previous_controlnet = curr_cn.previous_controlnet
            if is_top_level:  # top-level controlnet becomes the new output
                output_object = new_cn
            _relink(parent_cn, new_cn)
            cache[curr_cn] = new_cn
            curr_cn = new_cn
        parent_cn = curr_cn
        curr_cn = curr_cn.previous_controlnet
        is_top_level = False
    return output_object
569
+
570
+
571
def restore_all_controlnet_conns(conds: list[list[dict[str]]]):
    """Restore stashed ``previous_controlnet`` links for every control found in ``conds``.

    ``None`` items are skipped. For each sub-cond, the element at index 1 is
    checked for a ``"control"`` key, and that control chain is handed to
    ``_restore_all_controlnet_conns``.
    """
    for main_cond in conds:
        if main_cond is None:
            continue
        for cond in main_cond:
            options = cond[1]
            if "control" in options:
                _restore_all_controlnet_conns(options["control"])
578
+
579
+
580
def _restore_all_controlnet_conns(input_object: ControlBase):
    """Walk a controlnet chain and undo any stashed ``previous_controlnet`` replacement.

    For every link that carries the ``ORIG_PREVIOUS_CONTROLNET`` attribute, the
    stashed value is written back to ``previous_controlnet`` and the stash
    attribute is removed. Traversal then continues along the restored link.
    """
    node = input_object
    while node is not None:
        if hasattr(node, ORIG_PREVIOUS_CONTROLNET):
            original_prev = getattr(node, ORIG_PREVIOUS_CONTROLNET)
            node.previous_controlnet = original_prev
            delattr(node, ORIG_PREVIOUS_CONTROLNET)
        node = node.previous_controlnet
588
+
589
+
590
def are_all_advanced_controlnet(input_object: ControlBase):
    """Return True if every link in the ``previous_controlnet`` chain is advanced.

    An empty chain (``input_object`` is None) counts as all-advanced.
    """
    node = input_object
    # advance while links are advanced; stopping early means a non-advanced link was found
    while node is not None and is_advanced_controlnet(node):
        node = node.previous_controlnet
    return node is None
598
+
599
+
594
600
def is_advanced_controlnet(input_object):
    """Return True if ``input_object`` exposes a ``sub_idxs`` attribute.

    The presence of that attribute is the only criterion checked here.
    """
    marker_attr = "sub_idxs"
    return hasattr(input_object, marker_attr)
596
602
@@ -749,55 +755,6 @@ class WeightsLoader(torch.nn.Module):
749
755
return control
750
756
751
757
752
- def load_controllllite (ckpt_path : str , controlnet_data : dict [str , Tensor ]= None , timestep_keyframe : TimestepKeyframeGroup = None ):
753
- if controlnet_data is None :
754
- controlnet_data = comfy .utils .load_torch_file (ckpt_path , safe_load = True )
755
- # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
756
- # first, split weights for each module
757
- module_weights = {}
758
- for key , value in controlnet_data .items ():
759
- fragments = key .split ("." )
760
- module_name = fragments [0 ]
761
- weight_name = "." .join (fragments [1 :])
762
-
763
- if module_name not in module_weights :
764
- module_weights [module_name ] = {}
765
- module_weights [module_name ][weight_name ] = value
766
-
767
- # next, load each module
768
- modules = {}
769
- for module_name , weights in module_weights .items ():
770
- # kohya planned to do something about how these should be chosen, so I'm not touching this
771
- # since I am not familiar with the logic for this
772
- if "conditioning1.4.weight" in weights :
773
- depth = 3
774
- elif weights ["conditioning1.2.weight" ].shape [- 1 ] == 4 :
775
- depth = 2
776
- else :
777
- depth = 1
778
-
779
- module = LLLiteModule (
780
- name = module_name ,
781
- is_conv2d = weights ["down.0.weight" ].ndim == 4 ,
782
- in_dim = weights ["down.0.weight" ].shape [1 ],
783
- depth = depth ,
784
- cond_emb_dim = weights ["conditioning1.0.weight" ].shape [0 ] * 2 ,
785
- mlp_dim = weights ["down.0.weight" ].shape [0 ],
786
- )
787
- # load weights into module
788
- module .load_state_dict (weights )
789
- modules [module_name ] = module
790
- if len (modules ) == 1 :
791
- module .is_first = True
792
-
793
- #logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
794
-
795
- patch_attn1 = LLLitePatch (modules = modules , patch_type = LLLitePatch .ATTN1 )
796
- patch_attn2 = LLLitePatch (modules = modules , patch_type = LLLitePatch .ATTN2 )
797
- control = ControlLLLiteAdvanced (patch_attn1 = patch_attn1 , patch_attn2 = patch_attn2 , timestep_keyframes = timestep_keyframe )
798
- return control
799
-
800
-
801
758
def load_svdcontrolnet (ckpt_path : str , controlnet_data : dict [str , Tensor ]= None , timestep_keyframe : TimestepKeyframeGroup = None , model = None ):
802
759
if controlnet_data is None :
803
760
controlnet_data = comfy .utils .load_torch_file (ckpt_path , safe_load = True )
0 commit comments