diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index b69e8a2..e4cf43f 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -702,11 +702,26 @@ def __init__(self, *args, **kwargs): self.was_within_range = False self.prev_sub_idxs = None self.prev_batched_number = None + + def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, *args, **kwargs): + patched_model = super().patch_model_lowvram(device_to, lowvram_model_memory, force_patch_weights, *args, **kwargs) + + # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules + remaining_tensors = list(self.model.state_dict().keys()) + named_modules = [] + for n, _ in self.model.named_modules(): + named_modules.append(n) + named_modules.append(f"{n}.weight") + named_modules.append(f"{n}.bias") + for name in named_modules: + if name in remaining_tensors: + remaining_tensors.remove(name) + + for key in remaining_tensors: + self.patch_weight_to_device(key, device_to) + if device_to is not None: + comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to)) - def patch_model(self, *args, **kwargs): - # patch as normal; used to need to do prepare_weights call to work with lowvram, but no longer needed - # will consider removing this override at some point since it does nothing at the moment - patched_model = super().patch_model(*args, **kwargs) return patched_model def pre_run(self, model: ModelPatcherAndInjector): diff --git a/animatediff/sampling.py b/animatediff/sampling.py index c711527..991129b 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -249,7 +249,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed # needed to account for VanillaTemporalModule self.orig_memory_required = model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames - self.orig_groupnorm_manual_cast_forward = comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights + self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult if SAMPLE_FALLBACK: # for backwards compatibility, for now @@ -267,7 +267,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara if not (info.mm_version == AnimateDiffVersion.V3 or (info.mm_format not in [AnimateDiffFormat.HOTSHOTXL] and info.sd_type == ModelTypeSD.SD1_5 and info.mm_version == AnimateDiffVersion.V2 and params.apply_v2_properly)): torch.nn.GroupNorm.forward = groupnorm_mm_factory(params) - comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) + comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack try: if model.load_device.type == "mps": @@ -293,7 +293,7 @@ def restore_functions(self, model: ModelPatcherAndInjector): model.model.memory_required = self.orig_memory_required openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed torch.nn.GroupNorm.forward = self.orig_groupnorm_forward - comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_manual_cast_forward + comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights comfy.samplers.sampling_function = self.orig_sampling_function comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult if SAMPLE_FALLBACK: # for backwards compatibility, for now diff --git a/pyproject.toml b/pyproject.toml index 5c6ebc5..8d755d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.0.0" +version = "1.0.1" license = "LICENSE" dependencies = []