Commit bde5115

Add bettertransformer reverse transform (huggingface#868)
* add reverse
* all tests should pass
* add adapted from
* fix tests
* fix test
1 parent 5236ee6 commit bde5115

14 files changed: +617 additions, -654 deletions
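This commit adds the plumbing to undo a BetterTransformer conversion: each BetterTransformer layer can now rebuild the original transformers module from the weights it holds. A minimal usage sketch, assuming the public entry points are `BetterTransformer.transform` and the newly added `BetterTransformer.reverse` (the model name is only an example):

# Sketch of how the reverse transform is meant to be used (entry-point names assumed, not taken from this diff).
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("bert-base-uncased")

# Swap transformers layers for their BetterTransformer (fast-path) counterparts.
bt_model = BetterTransformer.transform(model)

# New in this PR: convert back to the original transformers modules,
# which internally calls each layer's _revert() shown in the diff below.
reverted_model = BetterTransformer.reverse(bt_model)
reverted_model.save_pretrained("reverted_model")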

optimum/bettertransformer/models/base.py (+31, -26)
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 
 if TYPE_CHECKING:
@@ -20,7 +20,7 @@
 import torch
 import torch.nn as nn
 
-from ...utils import logging, recurse_setattr
+from ...utils import logging, recurse_getattr, recurse_setattr
 
 
 KNOWN_ACTIVATION_ATTRIBUTES = ["hidden_act", "activation", "act_fn", "activation_function"]
@@ -38,7 +38,6 @@ class BetterTransformerBaseLayer(nn.Module):
     def __init__(
         self,
         config: "PretrainedConfig",
-        orig_layer: Optional[nn.Module] = None,
     ):
         r"""
         Base layer for `BetterTransformer` integration. This class is used to wrap all the necessary
@@ -57,6 +56,7 @@ def __init__(
         self.embed_dim = None
         self.num_layers = None
         self.original_layers_mapping = {}
+        self.module_mapping = None
         self.is_decoder = False
         # Some models does not have some attributes thus needs to be ignored
         # e.g. whisper does not have self_attn.k_proj.bias but has self_attn.v_proj.bias & self_attn.q_proj.bias
@@ -84,14 +84,6 @@ def __init__(
                 self.num_layers = getattr(config, attr)
                 break
 
-        # TODO: re-enable once fixed
-        # if orig_layer is not None:
-        #     # Last step, store the old module skeleton by copying the old module and putting
-        #     # it on the `meta` device.
-        #     self.orig_layer = deepcopy(orig_layer).to("meta")
-        # else:
-        #     self.orig_layer = orig_layer
-
     def validate_bettertransformer(self):
         r"""
         A wrapper function to validate the `BetterTransformer` implementation. Implements most relevant checks
@@ -147,32 +139,45 @@ def forward_checker(self, *args, **kwargs):
                 " Please use `model.eval()` before running the model.",
             )
 
-    def _revert_back_to_original_module(self):
-        r"""
-        A wrapper function to replace the current layer with the previous non-BetterTransformer
-        layer.
-        """
+    def _revert(self, module: torch.nn.Module) -> torch.nn.Module:
+        if self.module_mapping is not None:
+            if "" in self.module_mapping.values():
+                for bt_module_attr_name, value in self.module_mapping.items():
+                    if value == "":
+                        module = getattr(self, bt_module_attr_name)
+                return module
+            else:
+                raise NotImplementedError("replacing a submodule in revert is not supported")
+
         for modified_layer_key_names, original_layer_key_names in self.original_layers_mapping.items():
             if isinstance(original_layer_key_names, list):
                 current_weight = getattr(self, modified_layer_key_names)
 
                 # Split the current weight n chunks - this is useful to split
                 # the qkv layers into q, k, v layers for example.
                 split_index = current_weight.shape[0] // len(original_layer_key_names)
-                for i, module in enumerate(original_layer_key_names):
+                for i, subparam_name in enumerate(original_layer_key_names):
+                    if recurse_getattr(module, subparam_name) is None:
+                        # this is for example the case if bias=False is set for a nn.Linear layer
+                        continue
+
                     if module not in self.keys_to_ignore:
-                        recurse_setattr(
-                            self.orig_layer,
-                            module,
-                            nn.Parameter(current_weight[i * split_index : (i + 1) * split_index]),
-                        )
+                        parameter = current_weight[i * split_index : (i + 1) * split_index]
+                        if isinstance(recurse_getattr(module, subparam_name), torch.nn.Parameter):
+                            parameter = torch.nn.Parameter(parameter)
+                        recurse_setattr(module, subparam_name, parameter)
             elif isinstance(original_layer_key_names, str):
-                if module not in self.keys_to_ignore:
-                    recurse_setattr(self.orig_layer, original_layer_key_names, getattr(self, modified_layer_key_names))
+                if recurse_getattr(module, original_layer_key_names) is None:
+                    # this is for example the case if bias=False is set for a nn.Linear layer
+                    continue
+
+                parameter = getattr(self, modified_layer_key_names)
+                if isinstance(recurse_getattr(module, original_layer_key_names), torch.nn.Parameter):
+                    parameter = torch.nn.Parameter(parameter)
+                recurse_setattr(module, original_layer_key_names, parameter)
             else:
                raise ValueError(
                     f"Invalid type {type(modified_layer_key_names)} for `original_layers_mapping`",
                     " please use either `str` or `list`.",
                 )
-
-        return self.orig_layer
+        return module
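To make the `split_index` arithmetic in `_revert` concrete, here is a small standalone sketch (illustrative only, not optimum code) of how a fused projection weight is sliced back into the per-projection parameters listed in `original_layers_mapping`:

# Illustration of the split performed by _revert for a fused q/k/v projection.
import torch
import torch.nn as nn

embed_dim = 8
# Fused weight stacking q, k and v along dim 0, as stored on a BetterTransformer layer.
in_proj_weight = nn.Parameter(torch.randn(3 * embed_dim, embed_dim))

original_names = ["q_proj.weight", "k_proj.weight", "v_proj.weight"]
split_index = in_proj_weight.shape[0] // len(original_names)  # one equal chunk per original parameter

for i, name in enumerate(original_names):
    chunk = in_proj_weight[i * split_index : (i + 1) * split_index]
    # _revert wraps each slice in nn.Parameter before recurse_setattr-ing it onto the original module.
    print(name, tuple(chunk.shape))  # -> (8, 8) for each projection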
