13
13
# limitations under the License.
14
14
import warnings
15
15
16
+ from .decoder_models import (
17
+ CodegenAttentionLayerBetterTransformer ,
18
+ GPT2AttentionLayerBetterTransformer ,
19
+ GPTNeoAttentionLayerBetterTransformer ,
20
+ OPTAttentionLayerBetterTransformer ,
21
+ )
16
22
from .encoder_models import (
17
23
AlbertLayerBetterTransformer ,
18
24
BartEncoderLayerBetterTransformer ,
@@ -36,18 +42,24 @@ class BetterTransformerManager:
36
42
"bert-generation" : ("BertGenerationLayer" , BertLayerBetterTransformer ),
37
43
"camembert" : ("CamembertLayer" , BertLayerBetterTransformer ),
38
44
"clip" : ("CLIPEncoderLayer" , CLIPLayerBetterTransformer ),
45
+ "codegen" : ("CodeGenAttention" , CodegenAttentionLayerBetterTransformer ),
39
46
"data2vec-text" : ("Data2VecTextLayer" , BertLayerBetterTransformer ),
40
47
"deit" : ("DeiTLayer" , ViTLayerBetterTransformer ),
41
48
"distilbert" : ("TransformerBlock" , DistilBertLayerBetterTransformer ),
42
49
"electra" : ("ElectraLayer" , BertLayerBetterTransformer ),
43
50
"ernie" : ("ErnieLayer" , BertLayerBetterTransformer ),
44
51
"fsmt" : ("EncoderLayer" , FSMTEncoderLayerBetterTransformer ),
52
+ "gpt2" : ("GPT2Attention" , GPT2AttentionLayerBetterTransformer ),
53
+ "gptj" : ("GPTJAttention" , GPT2AttentionLayerBetterTransformer ),
54
+ "gpt_neo" : ("GPTNeoSelfAttention" , GPTNeoAttentionLayerBetterTransformer ),
55
+ "gpt_neox" : ("GPTNeoXAttention" , GPT2AttentionLayerBetterTransformer ),
45
56
"hubert" : ("HubertEncoderLayer" , Wav2Vec2EncoderLayerBetterTransformer ),
46
57
"layoutlm" : ("LayoutLMLayer" , BertLayerBetterTransformer ),
47
58
"m2m_100" : ("M2M100EncoderLayer" , MBartEncoderLayerBetterTransformer ),
48
59
"marian" : ("MarianEncoderLayer" , BartEncoderLayerBetterTransformer ),
49
60
"markuplm" : ("MarkupLMLayer" , BertLayerBetterTransformer ),
50
61
"mbart" : ("MBartEncoderLayer" , MBartEncoderLayerBetterTransformer ),
62
+ "opt" : ("OPTAttention" , OPTAttentionLayerBetterTransformer ),
51
63
"rembert" : ("RemBertLayer" , BertLayerBetterTransformer ),
52
64
"roberta" : ("RobertaLayer" , BertLayerBetterTransformer ),
53
65
"roc_bert" : ("RoCBertLayer" , BertLayerBetterTransformer ),
@@ -73,9 +85,9 @@ class BetterTransformerManager:
73
85
}
74
86
75
87
# Model types that cannot be converted, mapped to the human-readable reason
# reported for each of them.
CAN_NOT_BE_SUPPORTED = {
    "deberta-v2": "DeBERTa v2 does not use a regular attention mechanism, which is not supported in PyTorch's BetterTransformer.",
    "glpn": "GLPN has a convolutional layer present in the FFN network, which is not supported in PyTorch's BetterTransformer.",
    "t5": "T5 uses attention bias, which is not supported in PyTorch's BetterTransformer.",
}
80
92
81
93
@staticmethod
@@ -100,6 +112,48 @@ def supports(model_type: str) -> bool:
100
112
"""
101
113
return model_type in BetterTransformerManager .MODEL_MAPPING
102
114
115
@staticmethod
def requires_nested_tensor(model_type: str) -> bool:
    """
    Returns True if the BetterTransformer implementation for a given architecture uses nested tensors, False otherwise.

    Args:
        model_type (`str`):
            The model type to check.

    Returns:
        `bool`: False for the model types listed below, True for every other model type.
    """
    # A tuple (not a set) keeps the original behavior for unhashable inputs:
    # membership falls back to equality instead of raising TypeError.
    no_nested_tensor_types = ("codegen", "gpt2", "gptj", "gpt_neo", "gpt_neox", "opt")
    return model_type not in no_nested_tensor_types
128
+
129
@staticmethod
def requires_strict_validation(model_type: str) -> bool:
    """
    Returns True if the architecture requires to make sure all conditions of `validate_bettertransformer` are met.

    Args:
        model_type (`str`):
            The model type to check.

    Returns:
        `bool`: False for the model types listed below, True for every other model type.
    """
    # A tuple (not a set) keeps the original behavior for unhashable inputs:
    # membership falls back to equality instead of raising TypeError.
    relaxed_validation_types = ("codegen", "gpt2", "gptj", "gpt_neo", "gpt_neox", "opt")
    return model_type not in relaxed_validation_types
142
+
143
@staticmethod
def requires_torch_20(model_type: str) -> bool:
    """
    Returns True if the architecture requires PyTorch 2.0 to be used with BetterTransformer.

    Args:
        model_type (`str`):
            The model type to check.

    Returns:
        `bool`: True for the model types listed below, False for every other model type.
    """
    # A tuple (not a set) keeps the original behavior for unhashable inputs:
    # membership falls back to equality instead of raising TypeError.
    torch_20_types = ("codegen", "gpt2", "gptj", "gpt_neo", "gpt_neox", "opt")
    return model_type in torch_20_types
156
+
103
157
104
158
class warn_uncompatible_save (object ):
105
159
def __init__ (self , callback ):
0 commit comments