@@ -429,11 +429,15 @@ class GroupNorm(disable_weight_init_clean_groupnorm.GroupNorm):
429
429
430
430
431
431
# adapted from comfy/sample.py
432
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int = 1, match_dim1=False, match_shape=False):
    """Resize ``mask`` to the spatial size described by ``shape``.

    The mask is viewed as ``(-1, 1, H, W)`` using its last two dims and then
    bilinearly interpolated to
    ``(shape[-2] * multiplier, shape[-1] * multiplier)``.

    Args:
        mask: tensor whose last two dims are the mask height and width.
        shape: target shape; only indexing and ``len()`` are used, so a
            ``torch.Size`` (e.g. ``x.shape``) works despite the annotation.
        multiplier: integer factor applied to the target height and width.
        match_dim1: when True, repeat the mask ``shape[1]`` times along dim 1.
            Requires ``shape`` to have at least 4 dims.
        match_shape: when True and ``shape`` has exactly 3 dims, drop the
            singleton dim 1 so the result is ``(N, H', W')``.

    Returns:
        The resized mask: ``(N, 1, H', W')`` by default, ``(N, shape[1], H', W')``
        with ``match_dim1``, or ``(N, H', W')`` for a 3-dim ``shape`` with
        ``match_shape``.

    Raises:
        ValueError: if ``match_dim1`` is True and ``shape`` has fewer than 4 dims.
    """
    # Validate before doing any work: with fewer than 4 dims, shape[1] is not a
    # repeat count for dim 1, so repeating by it would silently build a
    # wrongly-shaped mask. This must hold regardless of match_shape.
    if match_dim1 and len(shape) < 4:
        raise ValueError(f"match_dim1 cannot be True if shape is under 4 dims; was {len(shape)}.")

    mask = mask.clone()
    target_size = (shape[-2] * multiplier, shape[-1] * multiplier)
    # Negative indices make the target size work for both 3-dim and 4+-dim shapes.
    mask = torch.nn.functional.interpolate(
        mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
        size=target_size,
        mode="bilinear",
    )
    if match_dim1:
        mask = torch.cat([mask] * shape[1], dim=1)
    if match_shape and len(shape) == 3:
        # Remove the singleton dim added by the reshape above.
        mask = mask.squeeze(1)
    return mask
438
442
439
443
@@ -823,10 +827,10 @@ def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
823
827
x [:] = x [:] * self .calc_latent_keyframe_mults (x = x , batched_number = batched_number )
824
828
# apply masks, resizing mask to required dims
825
829
if self .mask_cond_hint is not None :
826
- masks = prepare_mask_batch (self .mask_cond_hint , x .shape )
830
+ masks = prepare_mask_batch (self .mask_cond_hint , x .shape , match_shape = True )
827
831
x [:] = x [:] * masks
828
832
if self .tk_mask_cond_hint is not None :
829
- masks = prepare_mask_batch (self .tk_mask_cond_hint , x .shape )
833
+ masks = prepare_mask_batch (self .tk_mask_cond_hint , x .shape , match_shape = True )
830
834
x [:] = x [:] * masks
831
835
# apply timestep keyframe strengths
832
836
if self ._current_timestep_keyframe .strength != 1.0 :
0 commit comments