Merge PR #382 from Kosinkadink/develop - lowvram fix
Made MotionModelPatcher handle pe's properly with lowvram
Kosinkadink authored May 18, 2024 · 2 parents 19cfb6b + 565261a · commit ae9bc7b
Showing 3 changed files with 23 additions and 8 deletions.
animatediff/model_injection.py (19 additions, 4 deletions)

@@ -702,11 +702,26 @@ def __init__(self, *args, **kwargs):
         self.was_within_range = False
         self.prev_sub_idxs = None
         self.prev_batched_number = None
 
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, *args, **kwargs):
+        patched_model = super().patch_model_lowvram(device_to, lowvram_model_memory, force_patch_weights, *args, **kwargs)
+
+        # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules
+        remaining_tensors = list(self.model.state_dict().keys())
+        named_modules = []
+        for n, _ in self.model.named_modules():
+            named_modules.append(n)
+            named_modules.append(f"{n}.weight")
+            named_modules.append(f"{n}.bias")
+        for name in named_modules:
+            if name in remaining_tensors:
+                remaining_tensors.remove(name)
+
+        for key in remaining_tensors:
+            self.patch_weight_to_device(key, device_to)
+            if device_to is not None:
+                comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to))
+
+        return patched_model
 
     def patch_model(self, *args, **kwargs):
         # patch as normal; used to need to do prepare_weights call to work with lowvram, but no longer needed
         # will consider removing this override at some point since it does nothing at the moment
         patched_model = super().patch_model(*args, **kwargs)
         return patched_model
 
     def pre_run(self, model: ModelPatcherAndInjector):
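The override works because state_dict() lists every registered parameter and buffer, while the named_modules() walk only covers tensor names of the form "module", "module.weight", and "module.bias"; a tensor attached directly to a module, like the motion module's positional encodings ("pe"), is left over and has to be cast by hand. A minimal standalone sketch of that subtraction (the TemporalAttention toy module is hypothetical, not code from this repo):

```python
import torch

# Hypothetical toy module (not from this repo) standing in for a motion module
# that registers its positional encodings directly as a tensor named "pe".
class TemporalAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)  # yields "proj.weight" and "proj.bias"
        self.register_buffer("pe", torch.zeros(1, 8, 4))  # yields just "pe"

model = TemporalAttention()

# Same subtraction as patch_model_lowvram above: start from every key in the
# state_dict, then cross off everything reachable through named_modules().
remaining_tensors = list(model.state_dict().keys())
named_modules = []
for n, _ in model.named_modules():
    named_modules.append(n)
    named_modules.append(f"{n}.weight")
    named_modules.append(f"{n}.bias")
for name in named_modules:
    if name in remaining_tensors:
        remaining_tensors.remove(name)

print(remaining_tensors)  # prints ['pe'], the tensors lowvram patching would miss
```

Under lowvram, only the tensors reachable through named_modules are handled by the parent class, which is presumably why the leftover pe tensors needed this explicit cast in this commit.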
animatediff/sampling.py (3 additions, 3 deletions)

@@ -249,7 +249,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
         self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed # needed to account for VanillaTemporalModule
         self.orig_memory_required = model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds
         self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames
-        self.orig_groupnorm_manual_cast_forward = comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights
+        self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights
         self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers
         self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult
         if SAMPLE_FALLBACK: # for backwards compatibility, for now

@@ -267,7 +267,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
         if not (info.mm_version == AnimateDiffVersion.V3 or
                 (info.mm_format not in [AnimateDiffFormat.HOTSHOTXL] and info.sd_type == ModelTypeSD.SD1_5 and info.mm_version == AnimateDiffVersion.V2 and params.apply_v2_properly)):
             torch.nn.GroupNorm.forward = groupnorm_mm_factory(params)
-            comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
+            comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
             # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack
             try:
                 if model.load_device.type == "mps":

@@ -293,7 +293,7 @@ def restore_functions(self, model: ModelPatcherAndInjector):
         model.model.memory_required = self.orig_memory_required
         openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed
         torch.nn.GroupNorm.forward = self.orig_groupnorm_forward
-        comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_manual_cast_forward
+        comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights
         comfy.samplers.sampling_function = self.orig_sampling_function
         comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult
         if SAMPLE_FALLBACK: # for backwards compatibility, for now
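The sampling.py change retargets the existing GroupNorm hack from comfy.ops.manual_cast to comfy.ops.disable_weight_init and renames the stashed attribute to match; the surrounding save/patch/restore pattern is unchanged. A minimal sketch of that pattern (the wrapper below is a stand-in, not the repo's groupnorm_mm_factory):

```python
import torch

# Save the original unbound method before patching, like inject_functions()
# stashing torch.nn.GroupNorm.forward.
orig_groupnorm_forward = torch.nn.GroupNorm.forward

def groupnorm_wrapper_factory(orig_forward):
    # Stand-in for groupnorm_mm_factory: a real wrapper would regroup
    # (batch*frames) video latents before normalizing.
    def wrapped_forward(self, input):
        return orig_forward(self, input)
    return wrapped_forward

torch.nn.GroupNorm.forward = groupnorm_wrapper_factory(orig_groupnorm_forward)
try:
    # every GroupNorm in the process now routes through the wrapper
    gn = torch.nn.GroupNorm(num_groups=2, num_channels=4)
    out = gn(torch.randn(1, 4, 8, 8))
finally:
    # mirror restore_functions(): always undo a class-level monkey-patch,
    # otherwise the hack leaks into unrelated models after sampling ends
    torch.nn.GroupNorm.forward = orig_groupnorm_forward
```

Restoring matters because the patch is process-wide class state, not per-model state, which is why restore_functions() puts every saved function back after sampling.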
pyproject.toml (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-animatediff-evolved"
 description = "Improved AnimateDiff integration for ComfyUI."
-version = "1.0.0"
+version = "1.0.1"
 license = "LICENSE"
 dependencies = []
