Skip to content

Commit 3d00251

Browse files
committed
modified prepare_mask_batch to work with sd3
1 parent f5a149c commit 3d00251

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

adv_control/utils.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -429,11 +429,15 @@ class GroupNorm(disable_weight_init_clean_groupnorm.GroupNorm):
429429

430430

431431
# adapted from comfy/sample.py
432-
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False):
432+
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False, match_shape=False):
433433
mask = mask.clone()
434-
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
434+
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*multiplier, shape[-1]*multiplier), mode="bilinear")
435435
if match_dim1:
436+
if match_shape and len(shape) < 4:
437+
raise Exception(f"match_dim1 cannot be True if shape is under 4 dims; was {len(shape)}.")
436438
mask = torch.cat([mask] * shape[1], dim=1)
439+
if match_shape and len(shape) == 3:
440+
mask = mask.squeeze(1)
437441
return mask
438442

439443

@@ -823,10 +827,10 @@ def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
823827
x[:] = x[:] * self.calc_latent_keyframe_mults(x=x, batched_number=batched_number)
824828
# apply masks, resizing mask to required dims
825829
if self.mask_cond_hint is not None:
826-
masks = prepare_mask_batch(self.mask_cond_hint, x.shape)
830+
masks = prepare_mask_batch(self.mask_cond_hint, x.shape, match_shape=True)
827831
x[:] = x[:] * masks
828832
if self.tk_mask_cond_hint is not None:
829-
masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape)
833+
masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape, match_shape=True)
830834
x[:] = x[:] * masks
831835
# apply timestep keyframe strengths
832836
if self._current_timestep_keyframe.strength != 1.0:

0 commit comments

Comments
 (0)