diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 19fda16..1830266 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -1148,6 +1148,8 @@ def forward( view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, ): + if scale_masks is None: + scale_masks = [None] * len(self.attention_blocks) # make view_options None if context_length > video_length, or if equal and equal not allowed if view_options: if view_options.context_length > video_length: