# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING


if TYPE_CHECKING:
@@ -20,7 +20,7 @@
import torch
import torch.nn as nn

-from ...utils import logging, recurse_setattr
+from ...utils import logging, recurse_getattr, recurse_setattr


KNOWN_ACTIVATION_ATTRIBUTES = ["hidden_act", "activation", "act_fn", "activation_function"]
@@ -38,7 +38,6 @@ class BetterTransformerBaseLayer(nn.Module):
    def __init__(
        self,
        config: "PretrainedConfig",
-        orig_layer: Optional[nn.Module] = None,
    ):
        r"""
        Base layer for `BetterTransformer` integration. This class is used to wrap all the necessary
@@ -57,6 +56,7 @@ def __init__(
        self.embed_dim = None
        self.num_layers = None
        self.original_layers_mapping = {}
+        self.module_mapping = None
        self.is_decoder = False
        # Some models does not have some attributes thus needs to be ignored
        # e.g. whisper does not have self_attn.k_proj.bias but has self_attn.v_proj.bias & self_attn.q_proj.bias
@@ -84,14 +84,6 @@ def __init__(
                self.num_layers = getattr(config, attr)
                break

-        # TODO: re-enable once fixed
-        # if orig_layer is not None:
-        #     # Last step, store the old module skeleton by copying the old module and putting
-        #     # it on the `meta` device.
-        #     self.orig_layer = deepcopy(orig_layer).to("meta")
-        # else:
-        #     self.orig_layer = orig_layer
-
    def validate_bettertransformer(self):
        r"""
        A wrapper function to validate the `BetterTransformer` implementation. Implements most relevant checks
@@ -147,32 +139,45 @@ def forward_checker(self, *args, **kwargs):
                " Please use `model.eval()` before running the model.",
            )

-    def _revert_back_to_original_module(self):
-        r"""
-        A wrapper function to replace the current layer with the previous non-BetterTransformer
-        layer.
-        """
+    def _revert(self, module: torch.nn.Module) -> torch.nn.Module:
+        if self.module_mapping is not None:
+            if "" in self.module_mapping.values():
+                for bt_module_attr_name, value in self.module_mapping.items():
+                    if value == "":
+                        module = getattr(self, bt_module_attr_name)
+                        return module
+            else:
+                raise NotImplementedError("replacing a submodule in revert is not supported")
+
        for modified_layer_key_names, original_layer_key_names in self.original_layers_mapping.items():
            if isinstance(original_layer_key_names, list):
                current_weight = getattr(self, modified_layer_key_names)

                # Split the current weight n chunks - this is useful to split
                # the qkv layers into q, k, v layers for example.
                split_index = current_weight.shape[0] // len(original_layer_key_names)
-                for i, module in enumerate(original_layer_key_names):
+                for i, subparam_name in enumerate(original_layer_key_names):
+                    if recurse_getattr(module, subparam_name) is None:
+                        # this is for example the case if bias=False is set for a nn.Linear layer
+                        continue
+
                    if module not in self.keys_to_ignore:
-                        recurse_setattr(
-                            self.orig_layer,
-                            module,
-                            nn.Parameter(current_weight[i * split_index : (i + 1) * split_index]),
-                        )
+                        parameter = current_weight[i * split_index : (i + 1) * split_index]
+                        if isinstance(recurse_getattr(module, subparam_name), torch.nn.Parameter):
+                            parameter = torch.nn.Parameter(parameter)
+                        recurse_setattr(module, subparam_name, parameter)
            elif isinstance(original_layer_key_names, str):
-                if module not in self.keys_to_ignore:
-                    recurse_setattr(self.orig_layer, original_layer_key_names, getattr(self, modified_layer_key_names))
+                if recurse_getattr(module, original_layer_key_names) is None:
+                    # this is for example the case if bias=False is set for a nn.Linear layer
+                    continue
+
+                parameter = getattr(self, modified_layer_key_names)
+                if isinstance(recurse_getattr(module, original_layer_key_names), torch.nn.Parameter):
+                    parameter = torch.nn.Parameter(parameter)
+                recurse_setattr(module, original_layer_key_names, parameter)
            else:
                raise ValueError(
                    f"Invalid type {type(modified_layer_key_names)} for `original_layers_mapping`",
                    " please use either `str` or `list`.",
                )
-
-        return self.orig_layer
+        return module
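
For reference, the newly imported `recurse_getattr` is what lets `_revert` read sub-parameters such as `self_attn.q_proj.bias` through dotted attribute paths (and detect the `None` left behind by `bias=False`). A minimal sketch of the dotted-path behaviour these helpers are relied on for, assuming the usual `optimum.utils` semantics; the `_sketch` names are illustrative, not the upstream implementation:

    import functools

    def recurse_getattr_sketch(obj, attr: str):
        # Follow a dotted path, e.g. recurse_getattr_sketch(layer, "self_attn.q_proj.bias").
        # For an nn.Linear created with bias=False this returns None, which is exactly
        # what _revert checks before skipping the sub-parameter.
        return functools.reduce(getattr, attr.split("."), obj)

    def recurse_setattr_sketch(obj, attr: str, value):
        # Set the attribute at the end of a dotted path.
        parent_path, _, leaf = attr.rpartition(".")
        parent = recurse_getattr_sketch(obj, parent_path) if parent_path else obj
        setattr(parent, leaf, value)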
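
The new `self.module_mapping` attribute drives the early-return branch of `_revert`: an empty-string value appears to mean that the corresponding BetterTransformer attribute stands in for the whole original module, so it is handed back directly instead of copying parameters. A small self-contained sketch of that convention, with made-up class and attribute names:

    import torch.nn as nn

    class ToyBTLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = nn.Linear(4, 4)
            # "" means: `self.wrapped` maps back to the entire original module.
            self.module_mapping = {"wrapped": ""}

        def revert_sketch(self, module: nn.Module) -> nn.Module:
            # Mirrors the fast path of _revert above.
            if self.module_mapping is not None and "" in self.module_mapping.values():
                for bt_module_attr_name, value in self.module_mapping.items():
                    if value == "":
                        return getattr(self, bt_module_attr_name)
            return module

    bt_layer = ToyBTLayer()
    assert bt_layer.revert_sketch(nn.Linear(4, 4)) is bt_layer.wrapped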
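
To make the `original_layers_mapping` path concrete: when a BetterTransformer layer fuses several projections into one tensor, the mapping records which original sub-parameters each equal slice belongs to, and `_revert` copies the slices back onto the passed-in module, re-wrapping them as `nn.Parameter` only when the destination was a parameter. A hedged, standalone sketch of that slicing scheme with made-up layer names and sizes:

    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        # Stand-in for an original attention block with separate projections.
        def __init__(self, hidden: int = 4):
            super().__init__()
            self.q_proj = nn.Linear(hidden, hidden, bias=False)
            self.k_proj = nn.Linear(hidden, hidden, bias=False)
            self.v_proj = nn.Linear(hidden, hidden, bias=False)

    original = ToyAttention()

    # The BetterTransformer side keeps a single fused tensor plus the mapping.
    in_proj_weight = torch.cat(
        [original.q_proj.weight, original.k_proj.weight, original.v_proj.weight], dim=0
    ).detach().clone()
    original_layers_mapping = {"in_proj_weight": ["q_proj.weight", "k_proj.weight", "v_proj.weight"]}

    # Revert: split into len(sub-parameters) equal chunks along dim 0, as in _revert.
    subparam_names = original_layers_mapping["in_proj_weight"]
    split_index = in_proj_weight.shape[0] // len(subparam_names)
    for i, subparam_name in enumerate(subparam_names):
        chunk = in_proj_weight[i * split_index : (i + 1) * split_index]
        parent_name, _, leaf = subparam_name.rpartition(".")
        # The destination slot is an nn.Parameter, so wrap the slice before assigning.
        setattr(original.get_submodule(parent_name), leaf, nn.Parameter(chunk))

    assert torch.equal(original.k_proj.weight, in_proj_weight[split_index : 2 * split_index])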