@@ -3,17 +3,19 @@
 import torch
 import os
 
+import comfy.ops
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
 import comfy.controlnet as comfy_cn
-from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, broadcast_image_to
+from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter
 from comfy.model_patcher import ModelPatcher
 
-from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper
+from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper, SparseConst
 from .control_lllite import LLLiteModule, LLLitePatch
 from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
 from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, ControlWeightType, ControlWeights, WeightTypeException,
-                    manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory)
+                    manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
+                    broadcast_image_to_extend, extend_to_batch_size)
 from .logger import logger
 
@@ -55,12 +57,15 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
                 del self.cond_hint
             self.cond_hint = None
             # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-            if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
-                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+            if self.sub_idxs is not None:
+                actual_cond_hint_orig = self.cond_hint_original
+                if self.cond_hint_original.size(0) < self.full_latent_length:
+                    actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
             else:
                 self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
 
         # prepare mask_cond_hint
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
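
For readers unfamiliar with the helper: `extend_to_batch_size` (imported from `.utils` above) pads a hint batch out to the full latent length before `sub_idxs` indexing, so hints shorter than the latent count no longer fall through to the unsubdivided path. A minimal sketch of the assumed behavior (truncate if too long, repeat the last frame if too short); the real helper lives in `.utils` and may differ in details:

```python
import torch
from torch import Tensor

def extend_to_batch_size(tensor: Tensor, batch_size: int) -> Tensor:
    # assumed semantics: truncate when over, pad by repeating last frame when under
    if tensor.size(0) > batch_size:
        return tensor[:batch_size]
    if tensor.size(0) < batch_size:
        remainder = batch_size - tensor.size(0)
        return torch.cat([tensor] + [tensor[-1:]] * remainder, dim=0)
    return tensor
```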
@@ -97,7 +102,7 @@ def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channel
 
     def control_merge_inject(self, control_input, control_output, control_prev, output_dtype):
         # if has uncond multiplier, need to make sure control shapes are the same batch size as expected
-        if self.weights.has_uncond_multiplier:
+        if self.weights.has_uncond_multiplier or self.weights.has_uncond_mask:
             if control_input is not None:
                 for i in range(len(control_input)):
                     x = control_input[i]
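
Context for the check above: ComfyUI stacks the cond and uncond passes along dim 0 (`batched_number` groups of `actual_length` latents each), so before a per-group multiplier or mask can be applied, every control tensor has to cover the full batch. A rough illustration; which half is uncond depends on ComfyUI's `cond_or_uncond` ordering, so treat the slicing as an assumption:

```python
import torch

actual_length, batched_number = 4, 2        # 4 latents per group: cond + uncond
x = torch.randn(actual_length, 320, 8, 8)   # control output covering one group only
full = actual_length * batched_number
if x.size(0) != full:
    x = x.repeat(full // x.size(0), 1, 1, 1)  # extend to the cond+uncond batch
x[actual_length:] *= 0.5                    # per-group scaling is now possible
```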
@@ -131,9 +136,12 @@ def get_control_advanced(self, x_noisy, t, cond, batched_number):
         if self.sub_idxs is not None:
             # cond hints
             full_cond_hint_original = self.cond_hint_original
+            actual_cond_hint_orig = full_cond_hint_original
             del self.cond_hint
             self.cond_hint = None
-            self.cond_hint_original = full_cond_hint_original[self.sub_idxs]
+            if full_cond_hint_original.size(0) < self.full_latent_length:
+                actual_cond_hint_orig = extend_to_batch_size(tensor=full_cond_hint_original, batch_size=self.full_latent_length)
+            self.cond_hint_original = actual_cond_hint_orig[self.sub_idxs]
         # mask hints
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
         return super().get_control(x_noisy, t, cond, batched_number)
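
A small worked example of the pad-then-window flow above, with made-up numbers: a 4-frame hint against a 16-frame latent sequence and a sliding context window:

```python
import torch

hint = torch.arange(4).float().view(4, 1, 1, 1)    # 4 hint frames
full_latent_length = 16
# pad to full length (assuming last-frame repetition, as sketched earlier)
padded = torch.cat([hint, hint[-1:].expand(12, -1, -1, -1)], dim=0)
sub_idxs = list(range(8, 12))                      # current context window
window_hint = padded[sub_idxs]                     # hint slice aligned to the window
print(window_hint.flatten().tolist())              # [3.0, 3.0, 3.0, 3.0]
```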
@@ -221,12 +229,15 @@ def get_control_advanced(self, x_noisy, t, cond, batched_number):
                 del self.cond_hint
             self.cond_hint = None
             # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-            if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
-                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+            if self.sub_idxs is not None:
+                actual_cond_hint_orig = self.cond_hint_original
+                if self.cond_hint_original.size(0) < self.full_latent_length:
+                    actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+                self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
             else:
                 self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
 
         # prepare mask_cond_hint
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
@@ -291,18 +302,33 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
             del self.cond_hint
         self.cond_hint = None
         # first, figure out which cond idxs are relevant, and where they fit in
-        cond_idxs = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length)
-
+        cond_idxs, hint_order = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length,
+                                                                               sub_idxs=self.sub_idxs if self.sparse_settings.is_context_aware() else None)
         range_idxs = list(range(full_length)) if self.sub_idxs is None else self.sub_idxs
         hint_idxs = []  # idxs in cond_idxs
-        local_idxs = []  # idx to pun in final cond_hint
+        local_idxs = []  # idx to put in final cond_hint
         for i,cond_idx in enumerate(cond_idxs):
             if cond_idx in range_idxs:
                 hint_idxs.append(i)
                 local_idxs.append(range_idxs.index(cond_idx))
+        # log_string = f"cond_idxs: {cond_idxs}, local_idxs: {local_idxs}, hint_idxs: {hint_idxs}, hint_order: {hint_order}"
+        # if self.sub_idxs is not None:
+        #     log_string += f" sub_idxs: {self.sub_idxs[0]}-{self.sub_idxs[-1]}"
+        # logger.warn(log_string)
+        # determine cond/uncond indexes that will get masked
+        self.local_sparse_idxs = []
+        self.local_sparse_idxs_inverse = list(range(x_noisy.size(0)))
+        for batch_idx in range(batched_number):
+            for i in local_idxs:
+                actual_i = i + (batch_idx * actual_length)
+                self.local_sparse_idxs.append(actual_i)
+                if actual_i in self.local_sparse_idxs_inverse:
+                    self.local_sparse_idxs_inverse.remove(actual_i)
         # sub_cond_hint now contains the hints relevant to current x_noisy
-        sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(self.device)
-
+        if hint_order is None:
+            sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(self.device)
+        else:
+            sub_cond_hint = self.cond_hint_original[hint_order][hint_idxs].to(dtype).to(self.device)
         # scale cond_hints to match noisy input
         if self.control_model.use_simplified_conditioning_embedding:
             # RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
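
To make the index bookkeeping above concrete, here is a walkthrough with made-up numbers: a context window covering frames 4 through 11, sparse hints at frames 0 and 8, and a cond+uncond batch. Only frame 8 lands inside the window, and its row is tracked in both halves of the batch:

```python
range_idxs = list(range(4, 12))   # current window; actual_length = 8
cond_idxs = [0, 8]                # frames that carry a sparse hint
hint_idxs, local_idxs = [], []
for i, cond_idx in enumerate(cond_idxs):
    if cond_idx in range_idxs:
        hint_idxs.append(i)                             # -> [1]
        local_idxs.append(range_idxs.index(cond_idx))   # -> [4]
# replicate hinted positions across the cond/uncond groups
batched_number, actual_length = 2, 8
local_sparse_idxs = [i + b * actual_length
                     for b in range(batched_number) for i in local_idxs]
print(local_sparse_idxs)          # [4, 12]: the hinted rows in both batch halves
```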
@@ -319,15 +345,15 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
         # prepare cond_mask (b, 1, h, w)
         cond_shape[1] = 1
         cond_mask = torch.zeros(cond_shape).to(dtype).to(self.device)
-        cond_mask[local_idxs] = 1.0
+        cond_mask[local_idxs] = self.sparse_settings.sparse_mask_mult * self.weights.extras.get(SparseConst.MASK_MULT, 1.0)
         # combine cond_hint and cond_mask into (b, c+1, h, w)
         if not self.sparse_settings.merged:
             self.cond_hint = torch.cat([self.cond_hint, cond_mask], dim=1)
         del sub_cond_hint
         del cond_mask
         # make cond_hint match x_noisy batch
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
 
         # prepare mask_cond_hint
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
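
The mask channel built above is what lets SparseCtrl distinguish hinted from unhinted frames: the hint tensor receives one extra channel that is nonzero only at frames carrying a real hint. A self-contained sketch of that packing (shapes and the plain 1.0 value match the pre-change code; the mult-scaled value is the new behavior):

```python
import torch

b, c, h, w = 8, 4, 64, 64
local_idxs = [2, 5]                                 # frames with real hints
cond_hint = torch.zeros(b, c, h, w)
cond_hint[local_idxs] = torch.randn(len(local_idxs), c, h, w)
cond_mask = torch.zeros(b, 1, h, w)
cond_mask[local_idxs] = 1.0                         # now scaled by sparse/weight mults
packed = torch.cat([cond_hint, cond_mask], dim=1)   # (b, c+1, h, w)
```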
@@ -342,6 +368,12 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
         return self.control_merge(None, control, control_prev, output_dtype)
 
+    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
+        # apply mults to indexes with and without a direct condhint
+        x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0)
+        x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0)
+        return super().apply_advanced_strengths_and_masks(x, batched_number)
+
     def pre_run_advanced(self, model, percent_to_timestep_function):
         super().pre_run_advanced(model, percent_to_timestep_function)
         if type(self.cond_hint_original) == PreprocSparseRGBWrapper:
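
The two mults above compose a global `SparseSettings` value with an optional per-weights override stored in `weights.extras`; both default to 1.0, so the method is a no-op unless something sets them. A tiny illustration (the `SparseConst` key string is internal to `control_sparsectrl`, so the literal here is a stand-in):

```python
sparse_hint_mult = 0.5             # from SparseSettings
extras = {}                        # weights.extras with no override set
effective = sparse_hint_mult * extras.get("sparse_hint_mult", 1.0)
print(effective)                   # 0.5; rows in local_sparse_idxs get scaled by this
```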
@@ -359,6 +391,8 @@ def cleanup_advanced(self):
         if self.latent_format is not None:
             del self.latent_format
             self.latent_format = None
+        self.local_sparse_idxs = None
+        self.local_sparse_idxs_inverse = None
 
     def copy(self):
         c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.device, self.load_device, self.manual_cast_dtype)
@@ -411,12 +445,15 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
             del self.cond_hint
         self.cond_hint = None
         # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
-        if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
-            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
+        if self.sub_idxs is not None:
+            actual_cond_hint_orig = self.cond_hint_original
+            if self.cond_hint_original.size(0) < self.full_latent_length:
+                actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+            self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         else:
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
-            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+            self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
         # some special logic here compared to other controlnets:
         # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
         # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
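
A quick arithmetic example of the divisibility caveat in those comments: integer division inside the attn patches shrinks odd latent dims, leaving the upsampled cond_emb short of x:

```python
h_latent = 63           # e.g. a 504-pixel-tall image / 8
h_down = h_latent // 2  # 31 after an integer downscale in an attn patch
h_up = h_down * 2       # 62 != 63 -> cond_emb ends up smaller than x
```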
@@ -551,13 +588,9 @@ def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, tim
             motion_data[key] = controlnet_data.pop(key)
     if len(motion_data) == 0:
         raise ValueError(f"No motion-related keys in '{ckpt_path}'; not a valid SparseCtrl model!")
-    motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data).to(comfy.model_management.unet_dtype())
-    missing, unexpected = motion_wrapper.load_state_dict(motion_data)
-    if len(missing) > 0 or len(unexpected) > 0:
-        logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")
 
     # now, load as if it was a normal controlnet - mostly copied from comfy load_controlnet function
-    controlnet_config = None
+    controlnet_config: dict[str] = None
     is_diffusers = False
     use_simplified_conditioning_embedding = False
     if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:
@@ -681,6 +714,12 @@ class WeightsLoader(torch.nn.Module):
     if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True
 
+    # actually load motion portion of model now
+    motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data, ops=controlnet_config.get("operations", None)).to(comfy.model_management.unet_dtype())
+    missing, unexpected = motion_wrapper.load_state_dict(motion_data)
+    if len(missing) > 0 or len(unexpected) > 0:
+        logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")
+
     # both motion portion and controlnet portions are loaded; bring them together if using motion model
     if sparse_settings.use_motion:
         motion_wrapper.inject(control_model)
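
The `ops=` argument threaded into `SparseCtrlMotionWrapper` is why the motion load now happens after config detection: the comfy ops class (e.g. `comfy.ops.disable_weight_init` or a manual-cast variant) is only chosen while `controlnet_config` is built. The usual pattern for consuming such an ops class looks roughly like this (illustrative module, not the real wrapper):

```python
import torch.nn as nn
import comfy.ops

class TemporalBlock(nn.Module):
    # illustrative only: shows how an ops class swaps in layer factories
    def __init__(self, dim: int, ops=comfy.ops.disable_weight_init):
        super().__init__()
        self.proj = ops.Linear(dim, dim)  # skips costly init / handles casting
```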