diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx
index 0c79d8c755..ce40d64c4d 100644
--- a/docs/source/bettertransformer/overview.mdx
+++ b/docs/source/bettertransformer/overview.mdx
@@ -46,8 +46,9 @@ The list of supported model below:
 - [CLIP](https://arxiv.org/abs/2103.00020)
 - [CodeGen](https://arxiv.org/abs/2203.13474)
 - [Data2VecText](https://arxiv.org/abs/2202.03555)
-- [DistilBert](https://arxiv.org/abs/1910.01108)
 - [DeiT](https://arxiv.org/abs/2012.12877)
+- [DETR](https://arxiv.org/abs/2005.12872)
+- [DistilBert](https://arxiv.org/abs/1910.01108)
 - [Electra](https://arxiv.org/abs/2003.10555)
 - [Ernie](https://arxiv.org/abs/1904.09223)
 - [FSMT](https://arxiv.org/abs/1907.06616)
diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index 286e90231f..7f115c27b6 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -38,6 +38,7 @@
     BartEncoderLayerBetterTransformer,
     BertLayerBetterTransformer,
     CLIPLayerBetterTransformer,
+    DetrEncoderLayerBetterTransformer,
     DistilBertLayerBetterTransformer,
     FSMTEncoderLayerBetterTransformer,
     MBartEncoderLayerBetterTransformer,
@@ -67,21 +68,22 @@ class BetterTransformerManager:
         "bert": {"BertLayer": BertLayerBetterTransformer},
         "bert-generation": {"BertGenerationLayer": BertLayerBetterTransformer},
         "blenderbot": {"BlenderbotAttention": BlenderbotAttentionLayerBetterTransformer},
+        "blip-2": {"T5Attention": T5AttentionLayerBetterTransformer},
         "bloom": {"BloomAttention": BloomAttentionLayerBetterTransformer},
         "camembert": {"CamembertLayer": BertLayerBetterTransformer},
-        "blip-2": {"T5Attention": T5AttentionLayerBetterTransformer},
         "clip": {"CLIPEncoderLayer": CLIPLayerBetterTransformer},
         "codegen": {"CodeGenAttention": CodegenAttentionLayerBetterTransformer},
         "data2vec-text": {"Data2VecTextLayer": BertLayerBetterTransformer},
         "deit": {"DeiTLayer": ViTLayerBetterTransformer},
+        "detr": {"DetrEncoderLayer": DetrEncoderLayerBetterTransformer},
         "distilbert": {"TransformerBlock": DistilBertLayerBetterTransformer},
         "electra": {"ElectraLayer": BertLayerBetterTransformer},
         "ernie": {"ErnieLayer": BertLayerBetterTransformer},
-        "fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
         "falcon": {"FalconAttention": FalconAttentionLayerBetterTransformer},
+        "fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
         "gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer},
         "gpt_bigcode": {"GPTBigCodeAttention": GPTBigCodeAttentionLayerBetterTransformer},
         "gptj": {"GPTJAttention": GPTJAttentionLayerBetterTransformer},
         "gpt_neo": {"GPTNeoSelfAttention": GPTNeoAttentionLayerBetterTransformer},
         "gpt_neox": {"GPTNeoXAttention": GPTNeoXAttentionLayerBetterTransformer},
         "hubert": {"HubertEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer},
@@ -99,8 +101,8 @@ class BetterTransformerManager:
         "mbart": {"MBartEncoderLayer": MBartEncoderLayerBetterTransformer},
         "opt": {"OPTAttention": OPTAttentionLayerBetterTransformer},
         "pegasus": {"PegasusAttention": PegasusAttentionLayerBetterTransformer},
-        "rembert": {"RemBertLayer": BertLayerBetterTransformer},
         "prophetnet": {"ProphetNetEncoderLayer": ProphetNetEncoderLayerBetterTransformer},
+        "rembert": {"RemBertLayer": BertLayerBetterTransformer},
         "roberta": {"RobertaLayer": BertLayerBetterTransformer},
         "roc_bert": {"RoCBertLayer": BertLayerBetterTransformer},
         "roformer": {"RoFormerLayer": BertLayerBetterTransformer},
{"RoFormerLayer": BertLayerBetterTransformer}, diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index 20f7f4de50..22f434c528 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -11,20 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING import torch import torch.nn as nn import torch.nn.functional as F +from transformers import PretrainedConfig from transformers.activations import ACT2FN from .base import BetterTransformerBaseLayer -if TYPE_CHECKING: - from transformers import PretrainedConfig - - class AlbertLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module): def __init__(self, albert_layer, config): r""" @@ -1189,6 +1185,123 @@ def forward(self, hidden_states, output_attentions: bool, *_, **__): return (hidden_states,) +class DetrEncoderLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module): + def __init__(self, detr_layer, config): + r""" + A simple conversion of the DetrEncoderLayer to its `BetterTransformer` implementation. + + Args: + detr_layer (`torch.nn.Module`): + The original `DetrEncoderLayer` where the weights needs to be retrieved. + """ + super().__init__(config) + # In_proj layer + self.in_proj_weight = nn.Parameter( + torch.cat( + [ + detr_layer.self_attn.q_proj.weight, + detr_layer.self_attn.k_proj.weight, + detr_layer.self_attn.v_proj.weight, + ] + ) + ) + self.in_proj_bias = nn.Parameter( + torch.cat( + [ + detr_layer.self_attn.q_proj.bias, + detr_layer.self_attn.k_proj.bias, + detr_layer.self_attn.v_proj.bias, + ] + ) + ) + + # Out proj layer + self.out_proj_weight = detr_layer.self_attn.out_proj.weight + self.out_proj_bias = detr_layer.self_attn.out_proj.bias + + # Linear layer 1 + self.linear1_weight = detr_layer.fc1.weight + self.linear1_bias = detr_layer.fc1.bias + + # Linear layer 2 + self.linear2_weight = detr_layer.fc2.weight + self.linear2_bias = detr_layer.fc2.bias + + # Layer norm 1 + self.norm1_eps = detr_layer.self_attn_layer_norm.eps + self.norm1_weight = detr_layer.self_attn_layer_norm.weight + self.norm1_bias = detr_layer.self_attn_layer_norm.bias + + # Layer norm 2 + self.norm2_eps = detr_layer.final_layer_norm.eps + self.norm2_weight = detr_layer.final_layer_norm.weight + self.norm2_bias = detr_layer.final_layer_norm.bias + + # Model hyper parameters + self.num_heads = detr_layer.self_attn.num_heads + self.embed_dim = detr_layer.self_attn.embed_dim + + # Last step: set the last layer to `False` -> this will be set to `True` when converting the model + self.is_last_layer = False + self.norm_first = True + + self.original_layers_mapping = { + "in_proj_weight": ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], + "in_proj_bias": ["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], + "out_proj_weight": "self_attn.out_proj.weight", + "out_proj_bias": "self_attn.out_proj.bias", + "linear1_weight": "fc1.weight", + "linear1_bias": "fc1.bias", + "linear2_weight": "fc2.weight", + "linear2_bias": "fc2.bias", + "norm1_weight": "self_attn_layer_norm.weight", + "norm1_bias": "self_attn_layer_norm.bias", + "norm2_weight": "final_layer_norm.weight", + "norm2_bias": "final_layer_norm.bias", + } + + self.validate_bettertransformer() + + def forward(self, hidden_states, attention_mask, 
+        if output_attentions:
+            raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
+
+        if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
+            attention_mask = None
+
+            hidden_states = torch._transformer_encoder_layer_fwd(
+                hidden_states,
+                self.embed_dim,
+                self.num_heads,
+                self.in_proj_weight,
+                self.in_proj_bias,
+                self.out_proj_weight,
+                self.out_proj_bias,
+                self.use_gelu,
+                self.norm_first,
+                self.norm1_eps,
+                self.norm1_weight,
+                self.norm1_bias,
+                self.norm2_weight,
+                self.norm2_bias,
+                self.linear1_weight,
+                self.linear1_bias,
+                self.linear2_weight,
+                self.linear2_bias,
+                attention_mask,
+            )
+
+            if hidden_states.is_nested and self.is_last_layer:
+                hidden_states = hidden_states.to_padded_tensor(0.0)
+        else:
+            raise NotImplementedError(
+                "Training and Autocast are not implemented for BetterTransformer + Detr. Please open an issue."
+            )
+
+        return (hidden_states,)
+
+
 class ViltLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
     def __init__(self, vilt_layer, config):
         r"""
diff --git a/setup.py b/setup.py
index 7a7f454684..3936d35d3d 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,7 @@
     "torchaudio",
     "einops",
     "invisible-watermark",
+    "timm",
 ]

 QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241,<=0.0.259"]
diff --git a/tests/bettertransformer/test_vision.py b/tests/bettertransformer/test_vision.py
index ea04936fab..8c61ea09ba 100644
--- a/tests/bettertransformer/test_vision.py
+++ b/tests/bettertransformer/test_vision.py
@@ -27,7 +27,18 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
     r"""
     Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
     """
-    SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
+    SUPPORTED_ARCH = [
+        "blip-2",
+        "clip",
+        "clip_text_model",
+        "deit",
+        "detr",
+        "vilt",
+        "vit",
+        "vit_mae",
+        "vit_msn",
+        "yolos",
+    ]

     def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
         if model_type == "vilt":
@@ -57,6 +68,14 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc
         if model_type == "blip-2":
             inputs["decoder_input_ids"] = inputs["input_ids"]

+        elif model_type == "detr":
+            # Assuming detr just needs an image
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+
+            feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
+            inputs = feature_extractor(images=image, return_tensors="pt")
+
         else:
             url = "http://images.cocodataset.org/val2017/000000039769.jpg"
             image = Image.open(requests.get(url, stream=True).raw)
diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py
index 113c59f63c..46306e40d3 100644
--- a/tests/bettertransformer/testing_utils.py
+++ b/tests/bettertransformer/testing_utils.py
@@ -40,6 +40,7 @@
     "codegen": "hf-internal-testing/tiny-random-CodeGenModel",
     "data2vec-text": "hf-internal-testing/tiny-random-Data2VecTextModel",
     "deit": "hf-internal-testing/tiny-random-deit",
+    "detr": "hf-internal-testing/tiny-random-DetrModel",
     "distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
     "electra": "hf-internal-testing/tiny-random-ElectraModel",
     "ernie": "hf-internal-testing/tiny-random-ErnieModel",
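
Not part of the diff, but a minimal smoke-test sketch reviewers could run with this branch installed. It converts the tiny DETR checkpoint registered in `testing_utils.py` above and compares the fused fast path against the eager model; everything used here is the existing `BetterTransformer.transform` API, nothing beyond this patch is assumed.

```python
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModel

from optimum.bettertransformer import BetterTransformer

# Tiny test checkpoint added in testing_utils.py; any DetrModel checkpoint should work.
model_id = "hf-internal-testing/tiny-random-DetrModel"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).eval()

# Swaps each DetrEncoderLayer for its DetrEncoderLayerBetterTransformer counterpart.
bt_model = BetterTransformer.transform(model, keep_original_model=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(images=image, return_tensors="pt")

# The fast path only runs in inference mode (training raises NotImplementedError).
with torch.no_grad():
    eager_out = model(**inputs).last_hidden_state
    bt_out = bt_model(**inputs).last_hidden_state

# The fused kernel should stay numerically close to the eager implementation.
print(torch.max(torch.abs(eager_out - bt_out)))
```

The equivalence checks in `BetterTransformersTestMixin` cover this more thoroughly; the snippet is only a quick local sanity check of the new conversion path.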