Skip to content

Commit

Permalink
cleanup conflicting type import
Browse files Browse the repository at this point in the history
  • Loading branch information
Kosinkadink committed Aug 16, 2024
1 parent 3c4d5d4 commit 541ece5
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions animatediff/motion_module_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Iterable, Tuple, Union, TYPE_CHECKING
import re
from dataclasses import dataclass
from collections.abc import Iterable
from collections.abc import Iterable as IterColl

import torch
from einops import rearrange, repeat
Expand Down Expand Up @@ -886,9 +886,6 @@ def set_video_length(self, video_length: int, full_length: int):
self.full_length = full_length

def set_scale_multiplier(self, idx: int, multiplier: Union[float, list[float], None]):
# if not isinstance(multiplier, Iterable):
# multiplier = [multiplier]
# multiplier = extend_list_to_batch_size(multiplier, self.get_attention_count())
for block in self.transformer_blocks:
block.set_scale_multiplier(idx, multiplier)

Expand All @@ -907,7 +904,7 @@ def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[lis
scale = scales
break

if type(scale) == Tensor or not isinstance(scale, Iterable):
if type(scale) == Tensor or not isinstance(scale, IterColl):
scale = [scale]
scale = extend_list_to_batch_size(scale, self.get_attention_count())
for idx, sub_scale in enumerate(scale):
Expand Down

0 comments on commit 541ece5

Please sign in to comment.