Skip to content

Commit dc48641

Browse files
authored
Merge PR #188 from Kosinkadink/comfyupdate
Removed device var to match new ComfyUI update
2 parents 74d0c56 + 934f55d commit dc48641

6 files changed

+32
-32
lines changed

adv_control/control.py

+20-20
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222

2323
class ControlNetAdvanced(ControlNet, AdvancedControlBase):
24-
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
25-
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
24+
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
25+
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
2626
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
2727
self.is_flux = False
2828
self.x_noisy_shape = None
@@ -82,7 +82,7 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
8282
comfy.model_management.load_models_gpu(loaded_models)
8383
if self.latent_format is not None:
8484
self.cond_hint = self.latent_format.process_in(self.cond_hint)
85-
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
85+
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
8686
if x_noisy.shape[0] != self.cond_hint.shape[0]:
8787
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
8888

@@ -126,7 +126,7 @@ def cleanup_advanced(self):
126126
@staticmethod
127127
def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
128128
to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
129-
global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device,
129+
global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, load_device=v.load_device,
130130
manual_cast_dtype=v.manual_cast_dtype)
131131
v.copy_to(to_return)
132132
return to_return
@@ -213,8 +213,8 @@ def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -
213213

214214

215215
class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
216-
def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None):
217-
super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling, device=device)
216+
def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False):
217+
super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling)
218218
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllora())
219219
# use some functions from ControlNetAdvanced
220220
self.get_control_advanced = ControlNetAdvanced.get_control_advanced.__get__(self, type(self))
@@ -237,14 +237,14 @@ def cleanup(self):
237237
@staticmethod
238238
def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
239239
to_return = ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
240-
global_average_pooling=v.global_average_pooling, device=v.device)
240+
global_average_pooling=v.global_average_pooling)
241241
v.copy_to(to_return)
242242
return to_return
243243

244244

245245
class SVDControlNetAdvanced(ControlNetAdvanced):
246-
def __init__(self, control_model: SVDControlNet, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
247-
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
246+
def __init__(self, control_model: SVDControlNet, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
247+
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
248248

249249
def set_cond_hint_inject(self, *args, **kwargs):
250250
to_return = super().set_cond_hint_inject(*args, **kwargs)
@@ -280,9 +280,9 @@ def get_control_advanced(self, x_noisy, t, cond, batched_number):
280280
actual_cond_hint_orig = self.cond_hint_original
281281
if self.cond_hint_original.size(0) < self.full_latent_length:
282282
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
283-
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
283+
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
284284
else:
285-
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
285+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
286286
if x_noisy.shape[0] != self.cond_hint.shape[0]:
287287
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
288288

@@ -311,8 +311,8 @@ def copy(self):
311311

312312

313313
class SparseCtrlAdvanced(ControlNetAdvanced):
314-
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, sparse_settings: SparseSettings=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
315-
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
314+
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, sparse_settings: SparseSettings=None, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
315+
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
316316
self.control_model_wrapped = SparseModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
317317
self.add_compatible_weight(ControlWeightType.SPARSECTRL)
318318
self.control_model: SparseControlNet = self.control_model # does nothing except help with IDE hints
@@ -377,25 +377,25 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
377377
self.local_sparse_idxs_inverse.remove(actual_i)
378378
# sub_cond_hint now contains the hints relevant to current x_noisy
379379
if hint_order is None:
380-
sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(self.device)
380+
sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(x_noisy.device)
381381
else:
382-
sub_cond_hint = self.cond_hint_original[hint_order][hint_idxs].to(dtype).to(self.device)
382+
sub_cond_hint = self.cond_hint_original[hint_order][hint_idxs].to(dtype).to(x_noisy.device)
383383
# scale cond_hints to match noisy input
384384
if self.control_model.use_simplified_conditioning_embedding:
385385
# RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
386386
sub_cond_hint = self.model_latent_format.process_in(sub_cond_hint) # multiplies by model scale factor
387-
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3], x_noisy.shape[2], "nearest-exact", "center").to(dtype).to(self.device)
387+
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3], x_noisy.shape[2], "nearest-exact", "center").to(dtype).to(x_noisy.device)
388388
else:
389389
# other SparseCtrl; inputs are typical images
390-
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
390+
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
391391
# prepare cond_hint (b, c, h ,w)
392392
cond_shape = list(sub_cond_hint.shape)
393393
cond_shape[0] = len(range_idxs)
394-
self.cond_hint = torch.zeros(cond_shape).to(dtype).to(self.device)
394+
self.cond_hint = torch.zeros(cond_shape).to(dtype).to(x_noisy.device)
395395
self.cond_hint[local_idxs] = sub_cond_hint[:]
396396
# prepare cond_mask (b, 1, h, w)
397397
cond_shape[1] = 1
398-
cond_mask = torch.zeros(cond_shape).to(dtype).to(self.device)
398+
cond_mask = torch.zeros(cond_shape).to(dtype).to(x_noisy.device)
399399
cond_mask[local_idxs] = self.sparse_settings.sparse_mask_mult * self.weights.extras.get(SparseConst.MASK_MULT, 1.0)
400400
# combine cond_hint and cond_mask into (b, c+1, h, w)
401401
if not self.sparse_settings.merged:
@@ -446,7 +446,7 @@ def cleanup_advanced(self):
446446
self.local_sparse_idxs_inverse = None
447447

448448
def copy(self):
449-
c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.device, self.load_device, self.manual_cast_dtype)
449+
c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.load_device, self.manual_cast_dtype)
450450
self.copy_to(c)
451451
self.copy_to_advanced(c)
452452
return c

adv_control/control_lllite.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,9 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
340340
actual_cond_hint_orig = self.cond_hint_original
341341
if self.cond_hint_original.size(0) < self.full_latent_length:
342342
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
343-
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
343+
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
344344
else:
345-
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
345+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
346346
if x_noisy.shape[0] != self.cond_hint.shape[0]:
347347
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
348348
# some special logic here compared to other controlnets:

adv_control/control_plusplus.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ def forward(self, x: Tensor, hint: list[Tensor], timesteps, context, y: Tensor=N
223223

224224

225225
class ControlNetPlusPlusAdvanced(ControlNet, AdvancedControlBase):
226-
def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
227-
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
226+
def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
227+
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
228228
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
229229
self.add_compatible_weight(ControlWeightType.CONTROLNETPLUSPLUS)
230230
# for IDE type hint purposes
@@ -319,13 +319,13 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number):
319319
self.cond_hint[pp_idx] = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
320320
else:
321321
self.cond_hint[pp_idx] = comfy.utils.common_upscale(pp_input.image, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
322-
self.cond_hint[pp_idx] = self.cond_hint[pp_idx].to(device=self.device, dtype=dtype)
322+
self.cond_hint[pp_idx] = self.cond_hint[pp_idx].to(device=x_noisy.device, dtype=dtype)
323323
self.cond_hint_shape = self.cond_hint[pp_idx].shape
324324
# prepare cond_hint_controls to match batchsize
325325
if self.cond_hint_types.count_nonzero() == 0:
326326
self.cond_hint_types = None
327327
else:
328-
self.cond_hint_types = self.cond_hint_types.unsqueeze(0).to(device=self.device, dtype=dtype).repeat(x_noisy.shape[0], 1)
328+
self.cond_hint_types = self.cond_hint_types.unsqueeze(0).to(device=x_noisy.device, dtype=dtype).repeat(x_noisy.shape[0], 1)
329329
for i in range(len(self.cond_hint)):
330330
if self.cond_hint[i] is not None:
331331
if x_noisy.shape[0] != self.cond_hint[i].shape[0]:

adv_control/control_reference.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def __init__(self, condhint: Tensor):
141141
class ReferenceAdvanced(ControlBase, AdvancedControlBase):
142142
CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4}
143143

144-
def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup, device=None):
145-
super().__init__(device)
144+
def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup):
145+
super().__init__()
146146
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
147147
# TODO: allow vae_optional to be used instead of preprocessor
148148
#require_vae=True
@@ -261,11 +261,11 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
261261
if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
262262
self.cond_hint = comfy.utils.common_upscale(
263263
self.cond_hint_original[self.sub_idxs],
264-
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(self.device)
264+
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
265265
else:
266266
self.cond_hint = comfy.utils.common_upscale(
267267
self.cond_hint_original,
268-
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(self.device)
268+
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
269269
if x_noisy.shape[0] != self.cond_hint.shape[0]:
270270
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
271271
# noise cond_hint based on sigma (current step)

adv_control/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond,
931931
# default dtype to be same as x_noisy
932932
if dtype is None:
933933
dtype = x_noisy.dtype
934-
setattr(self, attr_name, out_mask.to(dtype=dtype).to(self.device))
934+
setattr(self, attr_name, out_mask.to(dtype=dtype).to(x_noisy.device))
935935
del out_mask
936936

937937
def _reset_attr(self, attr_name, new_value=None):

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-advanced-controlnet"
33
description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
4-
version = "1.2.3"
4+
version = "1.3.0"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)