
Commit e2313c4

Merge pull request #393 from Kosinkadink/develop

Prepare for ComfyUI to remove model_keys prop from ModelPatcher

Kosinkadink authored Jun 3, 2024
2 parents eba6a50 + 5d9ffaa
Showing 2 changed files with 19 additions and 11 deletions.
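Every model_injection.py hunk below applies the same two-part pattern: attribute copies of model_keys are guarded with hasattr() so they become no-ops once ComfyUI drops the property, and key-membership checks are answered by the underlying model's state_dict() instead of the cached key set. A minimal sketch of the pattern, where both function names and the src/dst objects are illustrative stand-ins that do not appear in the diff:

# Minimal sketch of this commit's compatibility pattern; `src` and `dst`
# stand in for ModelPatcher-like objects, and both function names are
# illustrative, not from the diff.

def copy_optional_model_keys(dst, src):
    # Mirror the cached key set only while older ComfyUI versions still
    # expose ModelPatcher.model_keys; newer versions remove the property.
    if hasattr(src, "model_keys"):
        dst.model_keys = src.model_keys

def patchable_keys(model, patches: dict) -> set:
    # Derive valid patch keys from the live state dict rather than the
    # removed model_keys cache.
    model_sd = model.state_dict()
    return {key for key in patches if key in model_sd}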
28 changes: 18 additions & 10 deletions animatediff/model_injection.py
@@ -42,7 +42,8 @@ def __init__(self, m: ModelPatcher):
 
         self.object_patches = m.object_patches.copy()
         self.model_options = copy.deepcopy(m.model_options)
-        self.model_keys = m.model_keys
+        if hasattr(m, "model_keys"):
+            self.model_keys = m.model_keys
         if hasattr(m, "backup"):
             self.backup = m.backup
         if hasattr(m, "object_patches_backup"):
@@ -141,7 +142,8 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
         p = set()
         for key in patches:
-            if key in self.model_keys:
+            model_sd = self.model.state_dict()
+            if key in model_sd:
                 p.add(key)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
                 current_patches.append((strength_patch, patches[key], strength_model))
Expand All @@ -159,7 +161,8 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {})
p = set()
for key in patches:
if key in self.model_keys:
model_sd = self.model.state_dict()
if key in model_sd:
p.add(key)
current_patches: list[tuple] = current_hooked_patches.get(key, [])
# take difference between desired weight and existing weight to get diff
@@ -423,7 +426,8 @@ def __init__(self, m: ModelPatcher):
 
         self.object_patches = m.object_patches.copy()
         self.model_options = copy.deepcopy(m.model_options)
-        self.model_keys = m.model_keys
+        if hasattr(m, "model_keys"):
+            self.model_keys = m.model_keys
         if hasattr(m, "backup"):
             self.backup = m.backup
         if hasattr(m, "object_patches_backup"):
@@ -484,7 +488,8 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
         p = set()
         for key in patches:
-            if key in self.model_keys:
+            model_sd = self.model.state_dict()
+            if key in model_sd:
                 p.add(key)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
                 current_patches.append((strength_patch, patches[key], strength_model))
@@ -501,7 +506,8 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
         current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook, {})
         p = set()
         for key in patches:
-            if key in self.model_keys:
+            model_sd = self.model.state_dict()
+            if key in model_sd:
                 p.add(key)
                 current_patches: list[tuple] = current_hooked_patches.get(key, [])
                 # take difference between desired weight and existing weight to get diff
@@ -627,7 +633,7 @@ def load_model_as_hooked_lora_for_models(model: Union[ModelPatcher, ModelPatcher
     if model is not None and model_loaded is not None:
         new_modelpatcher = ModelPatcherAndInjector.create_from(model)
         comfy.model_management.unload_model_clones(new_modelpatcher)
-        expected_model_keys = model_loaded.model_keys.copy()
+        expected_model_keys = set(model_loaded.model.state_dict().keys())
         patches_model: dict[str, Tensor] = model_loaded.model.state_dict()
         # do not include ANY model_sampling components of the model that should act as a patch
         for key in list(patches_model.keys()):
@@ -642,7 +648,7 @@ def load_model_as_hooked_lora_for_models(model: Union[ModelPatcher, ModelPatcher
     if clip is not None and clip_loaded is not None:
         new_clip = CLIPWithHooks(clip)
         comfy.model_management.unload_model_clones(new_clip.patcher)
-        expected_clip_keys = clip_loaded.patcher.model_keys.copy()
+        expected_clip_keys = clip_loaded.patcher.model.state_dict().copy()
         patches_clip: dict[str, Tensor] = clip_loaded.cond_stage_model.state_dict()
         k1 = new_clip.add_hooked_patches_as_diffs(lora_hook=lora_hook, patches=patches_clip, strength_patch=strength_clip)
     else:
@@ -889,7 +895,8 @@ def clone(self):
 
         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
-        n.model_keys = self.model_keys
+        if hasattr(n, "model_keys"):
+            n.model_keys = self.model_keys
         if hasattr(n, "backup"):
             self.backup = n.backup
         if hasattr(n, "object_patches_backup"):
@@ -982,7 +989,8 @@ def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
 
     model.object_patches = m.object_patches.copy()
     model.model_options = copy.deepcopy(m.model_options)
-    model.model_keys = m.model_keys
+    if hasattr(model, "model_keys"):
+        model.model_keys = m.model_keys
     return model


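The two load_model_as_hooked_lora_for_models hunks above apply the same idea at call sites: the expected key sets are rebuilt from state_dict() each time rather than copied from the cached set. A hedged sketch of the model branch, assuming the model_sampling filter (whose condition the truncated hunk does not show) matches keys by prefix:

# Sketch only: `model_loaded` is the ModelPatcher whose weights serve as
# the patch source. The startswith("model_sampling") test is an assumed
# stand-in; the truncated hunk does not show the real filter condition.
expected_model_keys = set(model_loaded.model.state_dict().keys())
patches_model = model_loaded.model.state_dict()
for key in list(patches_model.keys()):
    if key.startswith("model_sampling"):
        patches_model.pop(key)
        expected_model_keys.discard(key)

Note that the CLIP branch keeps a dict (state_dict().copy()) where the model branch builds a set; since `key in some_dict` tests key membership, both behave the same in later containment checks.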
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-animatediff-evolved"
 description = "Improved AnimateDiff integration for ComfyUI."
-version = "1.0.2"
+version = "1.0.3"
 license = "LICENSE"
 dependencies = []
