Skip to content

Commit

Permalink
Merge PR #387 from Kosinkadink/develop - new fuse_methods by Quasimondo
Browse files Browse the repository at this point in the history
Merging fuse_methods by Quasimondo
  • Loading branch information
Kosinkadink authored May 21, 2024
2 parents ae9bc7b + 225f069 commit eba6a50
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 7 deletions.
94 changes: 89 additions & 5 deletions animatediff/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ class ContextFuseMethod:
FLAT = "flat"
PYRAMID = "pyramid"
RELATIVE = "relative"
RANDOM = "random"
GAUSS_SIGMA = "gauss-sigma"
GAUSS_SIGMA_INV = "gauss-sigma inverse"
DELAYED_REVERSE_SAWTOOTH = "delayed reverse sawtooth"
PYRAMID_SIGMA = "pyramid-sigma"
PYRAMID_SIGMA_INV = "pyramid-sigma inverse"

LIST = [PYRAMID, FLAT]
LIST_STATIC = [PYRAMID, RELATIVE, FLAT]
LIST = [PYRAMID, FLAT, DELAYED_REVERSE_SAWTOOTH, PYRAMID_SIGMA, PYRAMID_SIGMA_INV, GAUSS_SIGMA, GAUSS_SIGMA_INV, RANDOM]
LIST_STATIC = [PYRAMID, RELATIVE, FLAT, DELAYED_REVERSE_SAWTOOTH, PYRAMID_SIGMA, PYRAMID_SIGMA_INV, GAUSS_SIGMA, GAUSS_SIGMA_INV, RANDOM]


class ContextType:
Expand Down Expand Up @@ -308,18 +314,17 @@ def get_context_windows(num_frames: int, opts: Union[ContextOptionsGroup, Contex
}


def get_context_weights(num_frames: int, fuse_method: str):
def get_context_weights(num_frames: int, fuse_method: str, sigma: Tensor = None):
weights_func = FUSE_MAPPING.get(fuse_method, None)
if not weights_func:
raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
return weights_func(num_frames)
return weights_func(num_frames, sigma=sigma )


def create_weights_flat(length: int, **kwargs) -> list[float]:
# weight is the same for all
return [1.0] * length


def create_weights_pyramid(length: int, **kwargs) -> list[float]:
# weight is based on the distance away from the edge of the context window;
# based on weighted average concept in FreeNoise paper
Expand All @@ -331,11 +336,90 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence

def create_weights_random(length: int, **kwargs) -> list[float]:
if length % 2 == 0:
max_weight = length // 2
else:
max_weight = (length + 1) // 2
return list(np.random.random(length)*max_weight+0.001)

def create_weights_gauss_sigma(length: int, **kwargs) -> list[float]:
sigma = 1.0 + 8.0*(min(4.0, kwargs["sigma"].mean().cpu()) / 4.0)
ax = np.linspace(-(length - 1) / 2., (length - 1) / 2., length)
w = np.exp(-0.5 * np.square(ax) / np.square(sigma))
if length % 2 == 0:
max_weight = length // 2
else:
max_weight = (length + 1) // 2
w *= max_weight / np.linalg.norm(w)
#print("create_weights_gauss_sigma sigma",sigma,w)
return list(w)

def create_weights_gauss_sigma_inv(length: int, **kwargs) -> list[float]:
sigma = 1.0 + 8.0*(1.0-min(4.0, kwargs["sigma"].mean().cpu()) / 4.0)
ax = np.linspace(-(length - 1) / 2., (length - 1) / 2., length)
w = np.exp(-0.5 * np.square(ax) / np.square(sigma))
if length % 2 == 0:
max_weight = length // 2
else:
max_weight = (length + 1) // 2
w *= max_weight / np.linalg.norm(w)
#print("create_weights_gauss_sigma_inv sigma",sigma,w)
return list(w)

def create_weights_pyramid_sigma_inv(length: int, **kwargs) -> list[float]:
sigma = min(4.0, kwargs["sigma"].mean().cpu()) / 4.0

if length % 2 == 0:
max_weight = length // 2
weight_sequence = np.array(list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)))
weight_sequence2 = np.array([-max_weight]*(max_weight-1) +[max_weight,max_weight] + [-max_weight]*(max_weight-1))
else:
max_weight = (length + 1) // 2
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
weight_sequence2 = np.array([-max_weight]*(max_weight) +[max_weight] + [-max_weight]*(max_weight-1))
weight_sequence = (sigma * weight_sequence2 + (1.0-sigma) * weight_sequence).clip(0.001,max_weight)
#print("create_weights_pyramid_sigma_inv",kwargs["sigma"].mean(),sigma, len(weight_sequence),weight_sequence)
return list(weight_sequence)

def create_weights_pyramid_sigma(length: int, **kwargs) -> list[float]:
sigma = min(4.0, kwargs["sigma"].mean().cpu()) / 4.0

if length % 2 == 0:
max_weight = length // 2
weight_sequence = np.array(list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)))
weight_sequence2 = np.array([-max_weight]*(max_weight-1) +[max_weight,max_weight] + [-max_weight]*(max_weight-1))
else:
max_weight = (length + 1) // 2
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
weight_sequence2 = np.array([-max_weight]*(max_weight) +[max_weight] + [-max_weight]*(max_weight-1))
weight_sequence = (sigma * weight_sequence + (1.0-sigma) * weight_sequence2).clip(0.001,max_weight)
#print("create_weights_pyramid_sigma",kwargs["sigma"].mean(),sigma, len(weight_sequence),weight_sequence)
return list(weight_sequence)

def create_weights_delayed_reverse_sawtooth(length: int, **kwargs) -> list[float]:
# assigns 0.01 to first half (or half-1 if even) of weights, then the rest of the weights are basically
# based on distance from context edge
if length % 2 == 0:
max_weight = length // 2
weight_sequence = [0.01]*(max_weight-1) + [max_weight] + list(range(max_weight, 0, -1))
else:
max_weight = (length + 1) // 2
weight_sequence = [0.01]*max_weight + [max_weight] + list(range(max_weight - 1, 0, -1))
#print("create_weights_delayed_falling_edge",len(weight_sequence),weight_sequence)
return weight_sequence


FUSE_MAPPING = {
ContextFuseMethod.FLAT: create_weights_flat,
ContextFuseMethod.PYRAMID: create_weights_pyramid,
ContextFuseMethod.RELATIVE: create_weights_pyramid,
ContextFuseMethod.GAUSS_SIGMA: create_weights_gauss_sigma,
ContextFuseMethod.GAUSS_SIGMA_INV: create_weights_gauss_sigma_inv,
ContextFuseMethod.RANDOM: create_weights_random,
ContextFuseMethod.DELAYED_REVERSE_SAWTOOTH: create_weights_delayed_reverse_sawtooth,
ContextFuseMethod.PYRAMID_SIGMA: create_weights_pyramid_sigma,
ContextFuseMethod.PYRAMID_SIGMA_INV: create_weights_pyramid_sigma_inv,
}


Expand Down
2 changes: 1 addition & 1 deletion animatediff/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list
biases_final[i][(full_length*n)+idx] = bias_total + bias
else:
# add conds and counts based on weights of fuse method
weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method) * batched_conds
weights = get_context_weights(len(ctx_idxs), ADGS.params.context_options.fuse_method, sigma=timestep) * batched_conds
weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
for i in range(len(sub_conds_out)):
conds_final[i][full_idxs] += sub_conds_out[i] * weights_tensor
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-animatediff-evolved"
description = "Improved AnimateDiff integration for ComfyUI."
version = "1.0.1"
version = "1.0.2"
license = "LICENSE"
dependencies = []

Expand Down

0 comments on commit eba6a50

Please sign in to comment.