@@ -114,9 +114,13 @@ def __init__(self, *args, **kwargs):
114
114
self .model : SparseControlNet
115
115
super ().__init__ (* args , ** kwargs )
116
116
117
- def patch_model_lowvram (self , device_to = None , * args , ** kwargs ):
118
- patched_model = super ().patch_model_lowvram (device_to , * args , ** kwargs )
117
+ def load (self , device_to = None , lowvram_model_memory = 0 , * args , ** kwargs ):
118
+ to_return = super ().load (device_to = device_to , lowvram_model_memory = lowvram_model_memory , * args , ** kwargs )
119
+ if lowvram_model_memory > 0 :
120
+ self ._patch_lowvram_extras (device_to = device_to )
121
+ return to_return
119
122
123
+ def _patch_lowvram_extras (self , device_to = None ):
120
124
if self .model .motion_wrapper is not None :
121
125
# figure out the tensors (likely pe's) that should be cast to device besides just the named_modules
122
126
remaining_tensors = list (self .model .motion_wrapper .state_dict ().keys ())
@@ -134,6 +138,10 @@ def patch_model_lowvram(self, device_to=None, *args, **kwargs):
134
138
if device_to is not None :
135
139
comfy .utils .set_attr (self .model .motion_wrapper , key , comfy .utils .get_attr (self .model .motion_wrapper , key ).to (device_to ))
136
140
141
+ # NOTE: no longer called by ComfyUI, but here for backwards compatibility
142
+ def patch_model_lowvram (self , device_to = None , * args , ** kwargs ):
143
+ patched_model = super ().patch_model_lowvram (device_to , * args , ** kwargs )
144
+ self ._patch_lowvram_extras (device_to = device_to )
137
145
return patched_model
138
146
139
147
def clone (self ):
0 commit comments