@@ -62,7 +62,10 @@ def refcn_sample(model: ModelPatcher, *args, **kwargs):
62
62
for i , module in enumerate (attn_modules ):
63
63
injection_holder = InjectionBasicTransformerBlockHolder (block = module , idx = i )
64
64
injection_holder .attn_weight = float (i ) / float (len (attn_modules ))
65
- module ._forward = _forward_inject_BasicTransformerBlock .__get__ (module , type (module ))
65
+ if hasattr (module , "_forward" ): # backward compatibility
66
+ module ._forward = _forward_inject_BasicTransformerBlock .__get__ (module , type (module ))
67
+ else :
68
+ module .forward = _forward_inject_BasicTransformerBlock .__get__ (module , type (module ))
66
69
module .injection_holder = injection_holder
67
70
reference_injections .attn_modules .append (module )
68
71
# figure out which module is middle block
@@ -430,14 +433,20 @@ def clean(self):
430
433
431
434
class InjectionBasicTransformerBlockHolder :
432
435
def __init__ (self , block : BasicTransformerBlock , idx = None ):
433
- self .original_forward = block ._forward
436
+ if hasattr (block , "_forward" ): # backward compatibility
437
+ self .original_forward = block ._forward
438
+ else :
439
+ self .original_forward = block .forward
434
440
self .idx = idx
435
441
self .attn_weight = 1.0
436
442
self .is_middle = False
437
443
self .bank_styles = BankStylesBasicTransformerBlock ()
438
444
439
445
def restore (self , block : BasicTransformerBlock ):
440
- block ._forward = self .original_forward
446
+ if hasattr (block , "_forward" ): # backward compatibility
447
+ block ._forward = self .original_forward
448
+ else :
449
+ block .forward = self .original_forward
441
450
442
451
def clean (self ):
443
452
self .bank_styles .clean ()
0 commit comments