Commit 6df7597

Add scaled_dot_product_attention support for decoder models (huggingface#853)
* add gpt2
* add files
* refactor tests
* support gpt2, better tests
* fix
* more models
* fix gpt neo
* small fixes
* add opt support
* fix tests
* add comment
* fix mock
* fix uninstall
* size
* last fix
1 parent 1303250 commit 6df7597

12 files changed: +981 −427 lines changed

.github/workflows/test_bettertransformer.yml (+12 −3)

@@ -28,9 +28,18 @@ jobs:
       - name: Install dependencies
         run: |
           pip install .[tests]
-          pip3 install --upgrade torch torchvision torchaudio
+          pip install --no-cache-dir --upgrade torch torchvision torchaudio
           pip install accelerate
-      - name: Test with unittest
+      - name: Test on pytorch stable
         working-directory: tests
         run: |
-          python -m unittest discover -s bettertransformer -p 'test_*.py'
+          pytest bettertransformer/test_*.py -s -vvvvv
+      - name: Install dependencies 2
+        run: |
+          pip uninstall -y torch torchvision torchaudio
+          pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+      - name: Test on pytorch nightly
+        working-directory: tests
+        run: |
+          pytest bettertransformer/test_*.py -s -vvvvv
+
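The nightly job matters because the new decoder-model path dispatches to torch.nn.functional.scaled_dot_product_attention, which only ships from PyTorch 2.0 onwards, so pre-2.0 stable wheels and 2.0-era nightlies exercise different code paths. A minimal sketch of the version gate involved (illustrative only, not code from this commit):

    import torch
    from packaging import version

    # scaled_dot_product_attention only exists from PyTorch 2.0 onwards,
    # hence the separate stable and nightly CI jobs above.
    def sdpa_available() -> bool:
        return version.parse(torch.__version__) >= version.parse("2.0")

    if sdpa_available():
        q = k = v = torch.rand(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)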

optimum/bettertransformer/models/__init__.py (+57 −3)

@@ -13,6 +13,12 @@
 # limitations under the License.
 import warnings
 
+from .decoder_models import (
+    CodegenAttentionLayerBetterTransformer,
+    GPT2AttentionLayerBetterTransformer,
+    GPTNeoAttentionLayerBetterTransformer,
+    OPTAttentionLayerBetterTransformer,
+)
 from .encoder_models import (
     AlbertLayerBetterTransformer,
     BartEncoderLayerBetterTransformer,
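The next hunk registers the decoder architectures in MODEL_MAPPING. Each new entry pairs a model type with the name of the attention class to locate in the loaded model and the BetterTransformer class that replaces it (note that gptj and gpt_neox reuse GPT2AttentionLayerBetterTransformer). A simplified sketch of how such a pair might be consumed, assuming the wrapper classes take the original module plus the model config; this is not the actual conversion code in optimum:

    import torch.nn as nn

    from optimum.bettertransformer.models import BetterTransformerManager

    def swap_attention_layers(model: nn.Module, model_type: str, config) -> nn.Module:
        # Look up which class name to search for and what to replace it with.
        target_name, bt_class = BetterTransformerManager.MODEL_MAPPING[model_type]
        # Collect matches first so the module tree is not mutated mid-iteration.
        matches = [
            (name, mod)
            for name, mod in model.named_modules()
            if mod.__class__.__name__ == target_name
        ]
        for name, mod in matches:
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            # Assumed constructor: (original module, model config).
            setattr(parent, child_name, bt_class(mod, config))
        return model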
@@ -36,18 +42,24 @@ class BetterTransformerManager:
         "bert-generation": ("BertGenerationLayer", BertLayerBetterTransformer),
         "camembert": ("CamembertLayer", BertLayerBetterTransformer),
         "clip": ("CLIPEncoderLayer", CLIPLayerBetterTransformer),
+        "codegen": ("CodeGenAttention", CodegenAttentionLayerBetterTransformer),
         "data2vec-text": ("Data2VecTextLayer", BertLayerBetterTransformer),
         "deit": ("DeiTLayer", ViTLayerBetterTransformer),
         "distilbert": ("TransformerBlock", DistilBertLayerBetterTransformer),
         "electra": ("ElectraLayer", BertLayerBetterTransformer),
         "ernie": ("ErnieLayer", BertLayerBetterTransformer),
         "fsmt": ("EncoderLayer", FSMTEncoderLayerBetterTransformer),
+        "gpt2": ("GPT2Attention", GPT2AttentionLayerBetterTransformer),
+        "gptj": ("GPTJAttention", GPT2AttentionLayerBetterTransformer),
+        "gpt_neo": ("GPTNeoSelfAttention", GPTNeoAttentionLayerBetterTransformer),
+        "gpt_neox": ("GPTNeoXAttention", GPT2AttentionLayerBetterTransformer),
         "hubert": ("HubertEncoderLayer", Wav2Vec2EncoderLayerBetterTransformer),
         "layoutlm": ("LayoutLMLayer", BertLayerBetterTransformer),
         "m2m_100": ("M2M100EncoderLayer", MBartEncoderLayerBetterTransformer),
         "marian": ("MarianEncoderLayer", BartEncoderLayerBetterTransformer),
         "markuplm": ("MarkupLMLayer", BertLayerBetterTransformer),
         "mbart": ("MBartEncoderLayer", MBartEncoderLayerBetterTransformer),
+        "opt": ("OPTAttention", OPTAttentionLayerBetterTransformer),
         "rembert": ("RemBertLayer", BertLayerBetterTransformer),
         "roberta": ("RobertaLayer", BertLayerBetterTransformer),
         "roc_bert": ("RoCBertLayer", BertLayerBetterTransformer),
@@ -73,9 +85,9 @@ class BetterTransformerManager:
     }
 
     CAN_NOT_BE_SUPPORTED = {
-        "deberta-v2": "DeBERTa v2 does not use a regular attention mechanism, which is not suppored in PyTorch's BetterTransformer.",
-        "glpn": "GLPN has a convolutional layer present in the FFN network, which is not suppored in PyTorch's BetterTransformer.",
-        "t5": "T5 uses attention bias, which is not suppored in PyTorch's BetterTransformer.",
+        "deberta-v2": "DeBERTa v2 does not use a regular attention mechanism, which is not supported in PyTorch's BetterTransformer.",
+        "glpn": "GLPN has a convolutional layer present in the FFN network, which is not supported in PyTorch's BetterTransformer.",
+        "t5": "T5 uses attention bias, which is not supported in PyTorch's BetterTransformer.",
     }
 
     @staticmethod
@@ -100,6 +112,48 @@ def supports(model_type: str) -> bool:
         """
         return model_type in BetterTransformerManager.MODEL_MAPPING
 
+    @staticmethod
+    def requires_nested_tensor(model_type: str) -> bool:
+        """
+        Returns True if the BetterTransformer implementation for a given architecture uses nested tensors, False otherwise.
+
+        Args:
+            model_type (`str`):
+                The model type to check.
+        """
+        if model_type in ["codegen", "gpt2", "gptj", "gpt_neo", "gpt_neox", "opt"]:
+            return False
+        else:
+            return True
+
+    @staticmethod
+    def requires_strict_validation(model_type: str) -> bool:
+        """
+        Returns True if the architecture requires to make sure all conditions of `validate_bettertransformer` are met.
+
+        Args:
+            model_type (`str`):
+                The model type to check.
+        """
+        if model_type in ["codegen", "gpt2", "gptj", "gpt_neo", "gpt_neox", "opt"]:
+            return False
+        else:
+            return True
+
+    @staticmethod
+    def requires_torch_20(model_type: str) -> bool:
+        """
+        Returns True if the architecture requires PyTorch 2.0 to be used with BetterTransformer.
+
+        Args:
+            model_type (`str`):
+                The model type to check.
+        """
+        if model_type in ["codegen", "gpt2", "gptj", "gpt_neo", "gpt_neox", "opt"]:
+            return True
+        else:
+            return False
+
 
 class warn_uncompatible_save(object):
     def __init__(self, callback):
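All three predicates key off the same list of decoder architectures: those models patch attention in place rather than converting to nested tensors, skip the strict validate_bettertransformer checks, and need PyTorch 2.0 for scaled_dot_product_attention. A hypothetical caller-side gate built on these helpers (the surrounding code is illustrative, not from this commit):

    import torch
    from packaging import version

    from optimum.bettertransformer.models import BetterTransformerManager

    model_type = "gpt2"
    assert BetterTransformerManager.supports(model_type)

    # Decoder models rely on torch.nn.functional.scaled_dot_product_attention.
    needs_torch_20 = BetterTransformerManager.requires_torch_20(model_type)  # True
    if needs_torch_20 and version.parse(torch.__version__) < version.parse("2.0"):
        raise ValueError(f"BetterTransformer for {model_type} requires PyTorch >= 2.0.")

    BetterTransformerManager.requires_nested_tensor(model_type)      # False for decoders
    BetterTransformerManager.requires_strict_validation(model_type)  # False for decoders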
