Skip to content

Commit 74d0c56

Browse files
authored
Merge PR #170 from Kosinkadink/develop - fix fp8 support for SparseCtrl
Fix fp8 support for SparseCtrl
2 parents dcc928b + 50db1e6 commit 74d0c56

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

adv_control/control_sparsectrl.py

+14
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from comfy.model_patcher import ModelPatcher
2828
import comfy.ops
2929
import comfy.model_management
30+
import comfy.utils
3031

3132
from .logger import logger
3233
from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm,
@@ -118,6 +119,7 @@ def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
118119
to_return = super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
119120
if lowvram_model_memory > 0:
120121
self._patch_lowvram_extras(device_to=device_to)
122+
self._handle_float8_pe_tensors()
121123
return to_return
122124

123125
def _patch_lowvram_extras(self, device_to=None):
@@ -138,6 +140,18 @@ def _patch_lowvram_extras(self, device_to=None):
138140
if device_to is not None:
139141
comfy.utils.set_attr(self.model.motion_wrapper, key, comfy.utils.get_attr(self.model.motion_wrapper, key).to(device_to))
140142

143+
def _handle_float8_pe_tensors(self):
144+
if self.model.motion_wrapper is not None:
145+
remaining_tensors = list(self.model.motion_wrapper.state_dict().keys())
146+
pe_tensors = [x for x in remaining_tensors if '.pe' in x]
147+
is_first = True
148+
for key in pe_tensors:
149+
if is_first:
150+
is_first = False
151+
if comfy.utils.get_attr(self.model.motion_wrapper, key).dtype not in [torch.float8_e5m2, torch.float8_e4m3fn]:
152+
break
153+
comfy.utils.set_attr(self.model.motion_wrapper, key, comfy.utils.get_attr(self.model.motion_wrapper, key).half())
154+
141155
# NOTE: no longer called by ComfyUI, but here for backwards compatibility
142156
def patch_model_lowvram(self, device_to=None, *args, **kwargs):
143157
patched_model = super().patch_model_lowvram(device_to, *args, **kwargs)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-advanced-controlnet"
33
description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
4-
version = "1.2.2"
4+
version = "1.2.3"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)