21
21
22
22
23
23
class ControlNetAdvanced (ControlNet , AdvancedControlBase ):
24
- def __init__ (self , control_model , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , compression_ratio = 8 , latent_format = None , device = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT ):
25
- super ().__init__ (control_model = control_model , global_average_pooling = global_average_pooling , compression_ratio = compression_ratio , latent_format = latent_format , device = device , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
24
+ def __init__ (self , control_model , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , compression_ratio = 8 , latent_format = None , load_device = None , manual_cast_dtype = None , extra_conds = ["y" ], strength_type = StrengthType .CONSTANT ):
25
+ super ().__init__ (control_model = control_model , global_average_pooling = global_average_pooling , compression_ratio = compression_ratio , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
26
26
AdvancedControlBase .__init__ (self , super (), timestep_keyframes = timestep_keyframes , weights_default = ControlWeights .controlnet ())
27
27
self .is_flux = False
28
28
self .x_noisy_shape = None
@@ -82,7 +82,7 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
82
82
comfy .model_management .load_models_gpu (loaded_models )
83
83
if self .latent_format is not None :
84
84
self .cond_hint = self .latent_format .process_in (self .cond_hint )
85
- self .cond_hint = self .cond_hint .to (device = self .device , dtype = dtype )
85
+ self .cond_hint = self .cond_hint .to (device = x_noisy .device , dtype = dtype )
86
86
if x_noisy .shape [0 ] != self .cond_hint .shape [0 ]:
87
87
self .cond_hint = broadcast_image_to_extend (self .cond_hint , x_noisy .shape [0 ], batched_number )
88
88
@@ -126,7 +126,7 @@ def cleanup_advanced(self):
126
126
@staticmethod
127
127
def from_vanilla (v : ControlNet , timestep_keyframe : TimestepKeyframeGroup = None ) -> 'ControlNetAdvanced' :
128
128
to_return = ControlNetAdvanced (control_model = v .control_model , timestep_keyframes = timestep_keyframe ,
129
- global_average_pooling = v .global_average_pooling , compression_ratio = v .compression_ratio , latent_format = v .latent_format , device = v . device , load_device = v .load_device ,
129
+ global_average_pooling = v .global_average_pooling , compression_ratio = v .compression_ratio , latent_format = v .latent_format , load_device = v .load_device ,
130
130
manual_cast_dtype = v .manual_cast_dtype )
131
131
v .copy_to (to_return )
132
132
return to_return
@@ -213,8 +213,8 @@ def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -
213
213
214
214
215
215
class ControlLoraAdvanced (ControlLora , AdvancedControlBase ):
216
- def __init__ (self , control_weights , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , device = None ):
217
- super ().__init__ (control_weights = control_weights , global_average_pooling = global_average_pooling , device = device )
216
+ def __init__ (self , control_weights , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False ):
217
+ super ().__init__ (control_weights = control_weights , global_average_pooling = global_average_pooling )
218
218
AdvancedControlBase .__init__ (self , super (), timestep_keyframes = timestep_keyframes , weights_default = ControlWeights .controllora ())
219
219
# use some functions from ControlNetAdvanced
220
220
self .get_control_advanced = ControlNetAdvanced .get_control_advanced .__get__ (self , type (self ))
@@ -237,14 +237,14 @@ def cleanup(self):
237
237
@staticmethod
238
238
def from_vanilla (v : ControlLora , timestep_keyframe : TimestepKeyframeGroup = None ) -> 'ControlLoraAdvanced' :
239
239
to_return = ControlLoraAdvanced (control_weights = v .control_weights , timestep_keyframes = timestep_keyframe ,
240
- global_average_pooling = v .global_average_pooling , device = v . device )
240
+ global_average_pooling = v .global_average_pooling )
241
241
v .copy_to (to_return )
242
242
return to_return
243
243
244
244
245
245
class SVDControlNetAdvanced (ControlNetAdvanced ):
246
- def __init__ (self , control_model : SVDControlNet , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , device = None , load_device = None , manual_cast_dtype = None ):
247
- super ().__init__ (control_model = control_model , timestep_keyframes = timestep_keyframes , global_average_pooling = global_average_pooling , device = device , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
246
+ def __init__ (self , control_model : SVDControlNet , timestep_keyframes : TimestepKeyframeGroup , global_average_pooling = False , load_device = None , manual_cast_dtype = None ):
247
+ super ().__init__ (control_model = control_model , timestep_keyframes = timestep_keyframes , global_average_pooling = global_average_pooling , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
248
248
249
249
def set_cond_hint_inject (self , * args , ** kwargs ):
250
250
to_return = super ().set_cond_hint_inject (* args , ** kwargs )
@@ -280,9 +280,9 @@ def get_control_advanced(self, x_noisy, t, cond, batched_number):
280
280
actual_cond_hint_orig = self .cond_hint_original
281
281
if self .cond_hint_original .size (0 ) < self .full_latent_length :
282
282
actual_cond_hint_orig = extend_to_batch_size (tensor = actual_cond_hint_orig , batch_size = self .full_latent_length )
283
- self .cond_hint = comfy .utils .common_upscale (actual_cond_hint_orig [self .sub_idxs ], x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (self .device )
283
+ self .cond_hint = comfy .utils .common_upscale (actual_cond_hint_orig [self .sub_idxs ], x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (x_noisy .device )
284
284
else :
285
- self .cond_hint = comfy .utils .common_upscale (self .cond_hint_original , x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (self .device )
285
+ self .cond_hint = comfy .utils .common_upscale (self .cond_hint_original , x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (x_noisy .device )
286
286
if x_noisy .shape [0 ] != self .cond_hint .shape [0 ]:
287
287
self .cond_hint = broadcast_image_to_extend (self .cond_hint , x_noisy .shape [0 ], batched_number )
288
288
@@ -311,8 +311,8 @@ def copy(self):
311
311
312
312
313
313
class SparseCtrlAdvanced (ControlNetAdvanced ):
314
- def __init__ (self , control_model , timestep_keyframes : TimestepKeyframeGroup , sparse_settings : SparseSettings = None , global_average_pooling = False , device = None , load_device = None , manual_cast_dtype = None ):
315
- super ().__init__ (control_model = control_model , timestep_keyframes = timestep_keyframes , global_average_pooling = global_average_pooling , device = device , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
314
+ def __init__ (self , control_model , timestep_keyframes : TimestepKeyframeGroup , sparse_settings : SparseSettings = None , global_average_pooling = False , load_device = None , manual_cast_dtype = None ):
315
+ super ().__init__ (control_model = control_model , timestep_keyframes = timestep_keyframes , global_average_pooling = global_average_pooling , load_device = load_device , manual_cast_dtype = manual_cast_dtype )
316
316
self .control_model_wrapped = SparseModelPatcher (self .control_model , load_device = load_device , offload_device = comfy .model_management .unet_offload_device ())
317
317
self .add_compatible_weight (ControlWeightType .SPARSECTRL )
318
318
self .control_model : SparseControlNet = self .control_model # does nothing except help with IDE hints
@@ -377,25 +377,25 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
377
377
self .local_sparse_idxs_inverse .remove (actual_i )
378
378
# sub_cond_hint now contains the hints relevant to current x_noisy
379
379
if hint_order is None :
380
- sub_cond_hint = self .cond_hint_original [hint_idxs ].to (dtype ).to (self .device )
380
+ sub_cond_hint = self .cond_hint_original [hint_idxs ].to (dtype ).to (x_noisy .device )
381
381
else :
382
- sub_cond_hint = self .cond_hint_original [hint_order ][hint_idxs ].to (dtype ).to (self .device )
382
+ sub_cond_hint = self .cond_hint_original [hint_order ][hint_idxs ].to (dtype ).to (x_noisy .device )
383
383
# scale cond_hints to match noisy input
384
384
if self .control_model .use_simplified_conditioning_embedding :
385
385
# RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
386
386
sub_cond_hint = self .model_latent_format .process_in (sub_cond_hint ) # multiplies by model scale factor
387
- sub_cond_hint = comfy .utils .common_upscale (sub_cond_hint , x_noisy .shape [3 ], x_noisy .shape [2 ], "nearest-exact" , "center" ).to (dtype ).to (self .device )
387
+ sub_cond_hint = comfy .utils .common_upscale (sub_cond_hint , x_noisy .shape [3 ], x_noisy .shape [2 ], "nearest-exact" , "center" ).to (dtype ).to (x_noisy .device )
388
388
else :
389
389
# other SparseCtrl; inputs are typical images
390
- sub_cond_hint = comfy .utils .common_upscale (sub_cond_hint , x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (self .device )
390
+ sub_cond_hint = comfy .utils .common_upscale (sub_cond_hint , x_noisy .shape [3 ] * 8 , x_noisy .shape [2 ] * 8 , 'nearest-exact' , "center" ).to (dtype ).to (x_noisy .device )
391
391
# prepare cond_hint (b, c, h ,w)
392
392
cond_shape = list (sub_cond_hint .shape )
393
393
cond_shape [0 ] = len (range_idxs )
394
- self .cond_hint = torch .zeros (cond_shape ).to (dtype ).to (self .device )
394
+ self .cond_hint = torch .zeros (cond_shape ).to (dtype ).to (x_noisy .device )
395
395
self .cond_hint [local_idxs ] = sub_cond_hint [:]
396
396
# prepare cond_mask (b, 1, h, w)
397
397
cond_shape [1 ] = 1
398
- cond_mask = torch .zeros (cond_shape ).to (dtype ).to (self .device )
398
+ cond_mask = torch .zeros (cond_shape ).to (dtype ).to (x_noisy .device )
399
399
cond_mask [local_idxs ] = self .sparse_settings .sparse_mask_mult * self .weights .extras .get (SparseConst .MASK_MULT , 1.0 )
400
400
# combine cond_hint and cond_mask into (b, c+1, h, w)
401
401
if not self .sparse_settings .merged :
@@ -446,7 +446,7 @@ def cleanup_advanced(self):
446
446
self .local_sparse_idxs_inverse = None
447
447
448
448
def copy (self ):
449
- c = SparseCtrlAdvanced (self .control_model , self .timestep_keyframes , self .sparse_settings , self .global_average_pooling , self .device , self . load_device , self .manual_cast_dtype )
449
+ c = SparseCtrlAdvanced (self .control_model , self .timestep_keyframes , self .sparse_settings , self .global_average_pooling , self .load_device , self .manual_cast_dtype )
450
450
self .copy_to (c )
451
451
self .copy_to_advanced (c )
452
452
return c
0 commit comments