 from comfy.model_patcher import ModelPatcher
 import comfy.ops
 import comfy.model_management
+import comfy.utils
 
 from .logger import logger
 from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm,
@@ -118,6 +119,7 @@ def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
         to_return = super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
         if lowvram_model_memory > 0:
             self._patch_lowvram_extras(device_to=device_to)
+            self._handle_float8_pe_tensors()
         return to_return
 
     def _patch_lowvram_extras(self, device_to=None):
@@ -138,6 +140,18 @@ def _patch_lowvram_extras(self, device_to=None):
                 if device_to is not None:
                     comfy.utils.set_attr(self.model.motion_wrapper, key, comfy.utils.get_attr(self.model.motion_wrapper, key).to(device_to))
 
+    def _handle_float8_pe_tensors(self):
+        if self.model.motion_wrapper is not None:
+            remaining_tensors = list(self.model.motion_wrapper.state_dict().keys())
+            pe_tensors = [x for x in remaining_tensors if '.pe' in x]
+            is_first = True
+            for key in pe_tensors:
+                if is_first:
+                    is_first = False
+                    if comfy.utils.get_attr(self.model.motion_wrapper, key).dtype not in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                        break
+                comfy.utils.set_attr(self.model.motion_wrapper, key, comfy.utils.get_attr(self.model.motion_wrapper, key).half())
+
     # NOTE: no longer called by ComfyUI, but here for backwards compatibility
     def patch_model_lowvram(self, device_to=None, *args, **kwargs):
         patched_model = super().patch_model_lowvram(device_to, *args, **kwargs)
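
The added _handle_float8_pe_tensors step runs only on the lowvram path and upcasts the motion model's positional-encoding ("pe") tensors to float16 when they were serialized in float8; if the first pe tensor is not a float8 dtype, everything is left untouched. Below is a minimal standalone sketch of that logic, assuming a plain torch.nn.Module stands in for self.model.motion_wrapper; upcast_float8_pe_tensors, _get_attr, and _set_attr are illustrative stand-ins for the PR's use of comfy.utils.get_attr/set_attr and are not names from the PR.

import functools

import torch


def _get_attr(obj, name: str):
    # Simplified stand-in for comfy.utils.get_attr: resolve a dotted attribute path.
    return functools.reduce(getattr, name.split("."), obj)


def _set_attr(obj, name: str, value: torch.Tensor) -> None:
    # Simplified stand-in for comfy.utils.set_attr: assign along a dotted path,
    # wrapping the tensor in a non-trainable Parameter.
    *parents, last = name.split(".")
    target = functools.reduce(getattr, parents, obj)
    setattr(target, last, torch.nn.Parameter(value, requires_grad=False))


def upcast_float8_pe_tensors(motion_module: torch.nn.Module) -> None:
    # Hypothetical helper mirroring _handle_float8_pe_tensors; requires a PyTorch
    # build that exposes the float8 dtypes (2.1+).
    pe_keys = [k for k in motion_module.state_dict().keys() if ".pe" in k]
    if not pe_keys:
        return
    # Only the first pe tensor's dtype is inspected; if it is not float8,
    # the remaining pe tensors are assumed fine and left as-is.
    if _get_attr(motion_module, pe_keys[0]).dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        return
    for key in pe_keys:
        _set_attr(motion_module, key, _get_attr(motion_module, key).half())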