From ce99718068abc132a8c8e752185fcb305c167352 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 10 Jul 2024 21:29:33 -0500 Subject: [PATCH 1/3] Added Visualize Context Options nodes, fixed bug with view options not getting their step properly set --- animatediff/context.py | 186 +++++++++++++++++++++++++++++++++-- animatediff/nodes.py | 11 ++- animatediff/nodes_context.py | 76 ++++++++++++-- 3 files changed, 257 insertions(+), 16 deletions(-) diff --git a/animatediff/context.py b/animatediff/context.py index 63ddd12..aa369b3 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -1,12 +1,16 @@ from typing import Callable, Optional, Union +import torch import torchvision import PIL +from PIL import Image, ImageFont, ImageDraw import numpy as np from torch import Tensor +import comfy.samplers from comfy.model_base import BaseModel +from comfy.model_patcher import ModelPatcher from .utils_motion import get_sorted_list_via_attr @@ -76,7 +80,7 @@ def __init__(self): self._current_context: ContextOptions = None self._current_used_steps: int = 0 self._current_index: int = 0 - self.step = 0 + self._step = 0 def reset(self): self._current_context = None @@ -85,6 +89,15 @@ def reset(self): self.step = 0 self._set_first_as_current() + @property + def step(self): + return self._step + @step.setter + def step(self, value: int): + self._step = value + if self._current_context is not None: + self._current_context.step = value + @classmethod def default(cls): def_context = ContextOptions() @@ -492,15 +505,174 @@ class Colors: CYAN = (0, 255, 255) +class BorderWidth: + INDEXES = 2 + CONTEXT = 4 + + class VisualizeSettings: - def __init__(self, img_width, img_height, video_length): - self.img_width = img_width - self.img_height = img_height + def __init__(self, img_width: int, video_length: int): self.video_length = video_length + self.img_width = img_width self.grid = img_width // video_length + self.img_height = self.grid * 5 self.pil_to_tensor = torchvision.transforms.Compose([torchvision.transforms.PILToTensor()]) + self.font_size = int(self.grid * 0.5) + self.font = ImageFont.load_default(size=self.font_size) + #self.title_font = ImageFont.load_default(size=int(self.font_size * 1.5)) + self.title_font = ImageFont.load_default(size=int(self.font_size * 1.2)) + self.background_color = Colors.BLACK + self.grid_outline_color = Colors.WHITE + self.start_idx_fill_color = Colors.MAGENTA + self.subidx_end_color = Colors.YELLOW -def generate_context_visualization(context_opts: ContextOptionsGroup, model: BaseModel, width=1440, height=200, video_length=32, start_step=0, end_step=20): - vs = VisualizeSettings(width, height, video_length) - pass + self.context_color = Colors.GREEN + self.view_color = Colors.RED + + +class GridDisplay: + def __init__(self, draw: ImageDraw.ImageDraw, vs: VisualizeSettings, home_x: int=0, home_y: int=0): + self.home_x = home_x + self.home_y = home_y + self.draw = draw + self.vs = vs + + +def get_text_xy(input: str, font: ImageFont, x: int, y: int, centered=True): + return (x, y,) + + +def draw_text(text: str, font: ImageFont, gd: GridDisplay, x: int, y: int, color=Colors.WHITE, centered=True): + x, y = get_text_xy(text, font, x, y, centered=centered) + gd.draw.text(xy=(gd.home_x+x, gd.home_y+y), text=text, fill=color, font=font) + + +def draw_first_grid_row(total_length: int, gd: GridDisplay, start_idx=-1): + vs = gd.vs + # the first row is white squares, with the indexes drawed in + for i in range(total_length): + x1 = gd.home_x+(vs.grid*i) + y1 = gd.home_y + x2 = x1 + vs.grid + y2 = y1 + vs.grid + + fill = None + if i==start_idx: + fill=vs.start_idx_fill_color + gd.draw.rectangle(xy=(x1, y1, x2, y2), fill=fill, outline=vs.grid_outline_color, width=BorderWidth.INDEXES) + draw_text(text=str(i), font=vs.font, gd=gd, x=vs.grid*i, y=0) + + +def draw_subidxs(window: list[int], gd: GridDisplay, y_grid_offset: int, color: tuple): + vs = gd.vs + # with no indexes drawed in- just solid squares, mostly + y_offset = vs.grid * y_grid_offset + for i, val in enumerate(window): + x1 = gd.home_x+(vs.grid*val) + y1 = gd.home_y+y_offset + x2 = x1 + vs.grid + y2 = y1 + vs.grid + fill_color = color + # if at an end of indexes, make inside be different color + if i == 0 or i == len(window)-1: + fill_color = vs.subidx_end_color + gd.draw.rectangle(xy=(x1, y1, x2, y2), fill=fill_color, outline=color, width=BorderWidth.CONTEXT) + + +def draw_context(window: list[int], gd: GridDisplay): + draw_subidxs(window=window, gd=gd, y_grid_offset=1, color=gd.vs.context_color) + + +def draw_view(window: list[int], gd: GridDisplay): + draw_subidxs(window=window, gd=gd, y_grid_offset=2, color=gd.vs.view_color) + + +def generate_context_visualization(context_opts: ContextOptionsGroup, model: ModelPatcher, sampler_name: str=None, scheduler: str=None, + width=1440, height=200, video_length=32, + steps=None, start_step=None, end_step=None, sigmas=None, force_full_denoise=False, denoise=None): + context_opts = context_opts.clone() + vs = VisualizeSettings(width, video_length) + all_imgs = [] + + if sigmas is None: + sampler = comfy.samplers.KSampler( + model=model, steps=steps, device="cpu", sampler=sampler_name, scheduler=scheduler, + denoise=denoise, model_options=model.model_options, + ) + sigmas = sampler.sigmas + if end_step is not None and end_step < (len(sigmas) - 1): + sigmas = sigmas[:end_step + 1] + if force_full_denoise: + sigmas[-1] = 0 + if start_step is not None: + if start_step < (len(sigmas) - 1): + sigmas = sigmas[start_step:] + # remove last sigma, as sampling uses pairs of sigmas at a time (fence post problem) + sigmas = sigmas[:-1] + + context_opts.reset() + context_opts.initialize_timesteps(model.model) + + if start_step is None: + start_step = 0 # use this in case start_step is provided, to display accurate step + + for i, t in enumerate(sigmas): + # make context_opts reflect current step/sigma + context_opts.prepare_current_context([t]) + context_opts.step = start_step+i + + # check if context should even be active in this case + context_active = True + if video_length < context_opts.context_length: + context_active = False + elif video_length == context_opts.context_length and not context_opts.use_on_equal_length: + context_active = False + + if context_active: + context_windows = get_context_windows(num_frames=video_length, opts=context_opts) + else: + context_windows = [list(range(video_length))] + start_idx = -1 + for j,window in enumerate(context_windows): + repeat_count = 0 + view_windows = [] + total_repeats = 1 + view_options = context_opts.view_options + if view_options is not None: + view_active = True + if len(window) < view_options.context_length: + view_active = False + elif video_length == view_options.context_length and not view_options.use_on_equal_length: + view_active = False + if view_active: + view_windows = get_context_windows(num_frames=len(window), opts=view_options) + total_repeats = len(view_windows) + while total_repeats > repeat_count: + # create new frame + frame: Image = Image.new(mode="RGB", size=(vs.img_width, vs.img_height), color=vs.background_color) + draw = ImageDraw.Draw(frame) + gd = GridDisplay(draw=draw, vs=vs, home_x=0, home_y=vs.grid) + # if views present, do view stuff + if len(view_windows) > 0: + converted_view = [window[x] for x in view_windows[repeat_count]] + draw_view(window=converted_view, gd=gd) + # draw context_type + current step + title_str = f"{context_opts.context_schedule} - Step {context_opts.step+1}/{steps} (Context {j+1}/{len(context_windows)})" + if len(view_windows) > 0: + title_str = f"{title_str} (View {repeat_count+1}/{len(view_windows)})" + draw_text(text=title_str, font=vs.title_font, gd=gd, x=0-gd.home_x, y=0-gd.home_y, centered=False) + # draw first row (total length, white) + if j == 0: + start_idx = window[0] + draw_first_grid_row(total_length=video_length, gd=gd, start_idx=start_idx) + # draw context row + draw_context(window=window, gd=gd) + # save image + iterate repeat_count + img: Tensor = vs.pil_to_tensor(frame) + all_imgs.append(img) + repeat_count += 1 + + images = torch.stack(all_imgs) + images = images.movedim(1, -1) + return images diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 363adc5..5597df2 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -25,7 +25,8 @@ NoisedImageInjectionNode, NoisedImageInjectOptionsNode) from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode) from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode, - StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, VisualizeContextOptionsInt) + StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, + VisualizeContextOptionsK, VisualizeContextOptionsKAdv, VisualizeContextOptionsSCustom) from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode, WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode, WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode) @@ -58,7 +59,9 @@ "ADE_ViewsOnlyContextOptions": ViewAsContextOptionsNode, "ADE_BatchedContextOptions": BatchedContextOptionsNode, "ADE_AnimateDiffUniformContextOptions": LegacyLoopedUniformContextOptionsNode, # Legacy - #"ADE_VisualizeContextOptions": VisualizeContextOptionsInt, + "ADE_VisualizeContextOptionsK": VisualizeContextOptionsK, + "ADE_VisualizeContextOptionsKAdv": VisualizeContextOptionsKAdv, + "ADE_VisualizeContextOptionsSCustom": VisualizeContextOptionsSCustom, # View Opts "ADE_StandardStaticViewOptions": StandardStaticViewOptionsNode, "ADE_StandardUniformViewOptions": StandardUniformViewOptionsNode, @@ -180,7 +183,9 @@ "ADE_ViewsOnlyContextOptions": "Context Optionsโ—†Views Only [VRAMโ‡ˆ] ๐ŸŽญ๐Ÿ…๐Ÿ…“", "ADE_BatchedContextOptions": "Context Optionsโ—†Batched [Non-AD] ๐ŸŽญ๐Ÿ…๐Ÿ…“", "ADE_AnimateDiffUniformContextOptions": "Context Optionsโ—†Looped Uniform ๐ŸŽญ๐Ÿ…๐Ÿ…“", # Legacy - "ADE_VisualizeContextOptions": "Visualize Context Options ๐ŸŽญ๐Ÿ…๐Ÿ…“", + "ADE_VisualizeContextOptionsK": "Visualize Context Options (K.) ๐ŸŽญ๐Ÿ…๐Ÿ…“", + "ADE_VisualizeContextOptionsKAdv": "Visualize Context Options (K.Adv.) ๐ŸŽญ๐Ÿ…๐Ÿ…“", + "ADE_VisualizeContextOptionsSCustom": "Visualize Context Options (S.Cus.) ๐ŸŽญ๐Ÿ…๐Ÿ…“", # View Opts "ADE_StandardStaticViewOptions": "View Optionsโ—†Standard Static ๐ŸŽญ๐Ÿ…๐Ÿ…“", "ADE_StandardUniformViewOptions": "View Optionsโ—†Standard Uniform ๐ŸŽญ๐Ÿ…๐Ÿ…“", diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 0940b76..a4ca94c 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -1,10 +1,12 @@ import torch from torch import Tensor +import comfy.samplers from comfy.model_patcher import ModelPatcher -from .context import ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules -from .utils_model import BIGMAX +from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, + generate_context_visualization) +from .utils_model import BIGMAX, MAX_RESOLUTION LENGTH_MAX = 128 # keep an eye on these max values; @@ -353,16 +355,20 @@ def create_options(self, view_length: int, view_overlap: int, view_stride: int, return (view_options,) -class VisualizeContextOptionsInt: +class VisualizeContextOptionsKAdv: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), "context_opts": ("CONTEXT_OPTIONS",), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), }, "optional": { + "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}), "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}), + "steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}), "start_step": ("INT", {"min": 0, "max": BIGMAX, "default": 0}), "end_step": ("INT", {"min": 1, "max": BIGMAX, "default": 20}), } @@ -372,7 +378,65 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff ๐ŸŽญ๐Ÿ…๐Ÿ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, - latents_length=32, start_step=0, end_step=20): - images = torch.zeros((latents_length, 256, 256, 3)) + def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + visual_width: 1280, latents_length=32, steps=20, start_step=0, end_step=20): + images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, + sampler_name=sampler_name, scheduler=scheduler, + steps=steps, start_step=start_step, end_step=end_step) + return (images,) + + +class VisualizeContextOptionsK: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "context_opts": ("CONTEXT_OPTIONS",), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), + }, + "optional": { + "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}), + "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}), + "steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("IMAGE",) + CATEGORY = "Animate Diff ๐ŸŽญ๐Ÿ…๐Ÿ…“/context opts/visualize" + FUNCTION = "visualize" + + def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + visual_width: 1280, latents_length=32, steps=20, denoise=1.0): + images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, + sampler_name=sampler_name, scheduler=scheduler, + steps=steps, denoise=denoise) + return (images,) + + +class VisualizeContextOptionsSCustom: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "context_opts": ("CONTEXT_OPTIONS",), + "sigmas": ("SIGMAS", ), + }, + "optional": { + "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}), + "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}), + } + } + + RETURN_TYPES = ("IMAGE",) + CATEGORY = "Animate Diff ๐ŸŽญ๐Ÿ…๐Ÿ…“/context opts/visualize" + FUNCTION = "visualize" + + def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigmas, + visual_width: 1280, latents_length=32): + images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, + sigmas=sigmas) return (images,) From 8ba29239bd540fd8fb1b6f28510161b2af9553ec Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 10 Jul 2024 22:57:41 -0500 Subject: [PATCH 2/3] Fix visualization images not being in float32, fix steps showing up as None when using S. Cus. node --- animatediff/context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/animatediff/context.py b/animatediff/context.py index aa369b3..b3a61fd 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -616,6 +616,8 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod if start_step is None: start_step = 0 # use this in case start_step is provided, to display accurate step + if steps is None: + steps = len(sigmas) for i, t in enumerate(sigmas): # make context_opts reflect current step/sigma @@ -674,5 +676,5 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod repeat_count += 1 images = torch.stack(all_imgs) - images = images.movedim(1, -1) + images = images.movedim(1, -1).to(torch.float32) return images From c8480f9b5715ac867d8003ead83d3f7ffda266d4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 10 Jul 2024 23:03:57 -0500 Subject: [PATCH 3/3] Cleaned imports, version bump --- animatediff/context.py | 3 +-- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/animatediff/context.py b/animatediff/context.py index b3a61fd..b44fa42 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -1,8 +1,7 @@ -from typing import Callable, Optional, Union +from typing import Union import torch import torchvision -import PIL from PIL import Image, ImageFont, ImageDraw import numpy as np diff --git a/pyproject.toml b/pyproject.toml index 5bd4395..30a7309 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.0.9" +version = "1.0.10" license = "LICENSE" dependencies = []