
Add BetterTransformer support for FlavaModel #538

Closed
Commits (26):
- e2fa9b4: add BetterTransformer support for FlavaModel and test (Dec 2, 2022)
- c76b1c4: Update optimum/bettertransformer/models/encoder_models.py (katiele47, Dec 3, 2022)
- fadd2ed: Merge branch 'main' of https://github.com/huggingface/optimum into ad… (Dec 3, 2022)
- f91a928: Merge branch 'add-better-transformers-support-for-flava' of https://g… (Dec 3, 2022)
- 072420c: Update tests/bettertransformer/test_bettertransformer_vision.py (katiele47, Dec 5, 2022)
- 4966b96: Merge branch 'main' of https://github.com/huggingface/optimum into ad… (Dec 5, 2022)
- 6c80658: Merge branch 'add-better-transformers-support-for-flava' of https://g… (Dec 5, 2022)
- 2a4b790: Optimum ONNX Runtime API improvement (#515) (michaelbenayoun, Dec 6, 2022)
- 8b559db: Add IO binding support for custom ORTModel (#447) (JingyaHuang, Dec 6, 2022)
- d026fdc: fix import (#553) (fxmarty, Dec 7, 2022)
- 037467d: Update readme (#550) (echarlaix, Dec 7, 2022)
- f6eb417: Refactor of 2 functions used in ORTModel (#551) (michaelbenayoun, Dec 7, 2022)
- 382077d: Update tests/bettertransformer/test_bettertransformer_vision.py (katiele47, Dec 7, 2022)
- 422f3d7: Update tests/bettertransformer/test_bettertransformer_vision.py (katiele47, Dec 7, 2022)
- fac2694: Update readme (#556) (echarlaix, Dec 7, 2022)
- 08d7917: applied make style (Dec 7, 2022)
- 6063fc4: Fix ORTTrainer wrapper duplication / PyTorch evaluate / update with t… (JingyaHuang, Dec 7, 2022)
- d169fc3: fix test (younesbelkada, Dec 8, 2022)
- 1588a2e: Add CLIP BetterTransformer (#534) (fxmarty, Dec 8, 2022)
- 86375c4: Fix flaky BetterTransformer test (#564) (fxmarty, Dec 8, 2022)
- 0970ec4: Support decoder generated with `--for-ort` from `optimum.exporters.on… (fxmarty, Dec 8, 2022)
- 5c90cf1: enable FP16Optimizer for fp16 deepspeed training. (#547) (AdamLouly, Dec 8, 2022)
- 1521f1b: fixed merge conflict due to rebase with upstream main (Dec 9, 2022)
- f48da36: merge conflict clip and flava (Dec 9, 2022)
- 0b8ed50: installed missing dependencies (Dec 9, 2022)
- a012ec7: applied make style (Dec 9, 2022)
1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx
@@ -33,6 +33,7 @@ The list of supported model below:
- [DeiT](https://arxiv.org/abs/2012.12877)
- [Electra](https://arxiv.org/abs/2003.10555)
- [Ernie](https://arxiv.org/abs/1904.09223)
- [Flava](https://arxiv.org/abs/2112.04482)
- [FSMT](https://arxiv.org/abs/1907.06616)
- [HuBERT](https://arxiv.org/pdf/2106.07447.pdf)
- [LayoutLM](https://arxiv.org/abs/1912.13318)
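
With Flava on the supported list, conversion goes through the usual BetterTransformer entry point. A minimal usage sketch (the checkpoint is the tiny test model used later in this PR; any `FlavaModel` checkpoint should behave the same way):

from transformers import FlavaModel
from optimum.bettertransformer import BetterTransformer

model = FlavaModel.from_pretrained("hf-internal-testing/tiny-random-FlavaModel")
# Replace supported encoder layers with their BetterTransformer equivalents.
model = BetterTransformer.transform(model, keep_original_model=False)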
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -18,6 +18,7 @@
BartEncoderLayerBetterTransformer,
BertLayerBetterTransformer,
DistilBertLayerBetterTransformer,
FlavaLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
ViltLayerBetterTransformer,
@@ -74,6 +75,7 @@
# FSMTModel:
"EncoderLayer": FSMTEncoderLayerBetterTransformer,
"ViltLayer": ViltLayerBetterTransformer,
"FlavaLayer": FlavaLayerBetterTransformer,
}


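
The mapping above is keyed by the class name of the layer to replace. Purely as an illustration of how such a registry can be consumed (the `convert` helper and the traversal below are assumptions for this sketch, not the actual optimum internals):

import torch

def convert(module: torch.nn.Module, mapping: dict, config) -> torch.nn.Module:
    # Walk the module tree; swap any child whose class name is registered.
    for name, child in module.named_children():
        if child.__class__.__name__ in mapping:
            setattr(module, name, mapping[child.__class__.__name__](child, config))
        else:
            convert(child, mapping, config)
    return module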
96 changes: 96 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
@@ -1068,3 +1068,99 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states, attention_mask)


class FlavaLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, flava_layer, config):
r"""
A simple conversion of the FlavaLayer to its `BetterTransformer` implementation.

Args:
flava_layer (`torch.nn.Module`):
                The original `FlavaLayer` where the weights need to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
flava_layer.attention.attention.query.weight,
flava_layer.attention.attention.key.weight,
flava_layer.attention.attention.value.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
flava_layer.attention.attention.query.bias,
flava_layer.attention.attention.key.bias,
flava_layer.attention.attention.value.bias,
]
)
)

# Out proj layer
self.out_proj_weight = flava_layer.attention.output.dense.weight
self.out_proj_bias = flava_layer.attention.output.dense.bias

# Linear layer 1
self.linear1_weight = flava_layer.intermediate.dense.weight
self.linear1_bias = flava_layer.intermediate.dense.bias

# Linear layer 2
self.linear2_weight = flava_layer.output.dense.weight
self.linear2_bias = flava_layer.output.dense.bias

# Layer norm 1
self.norm1_eps = flava_layer.layernorm_before.eps
self.norm1_weight = flava_layer.layernorm_before.weight
self.norm1_bias = flava_layer.layernorm_before.bias

# Layer norm 2
self.norm2_eps = flava_layer.layernorm_after.eps
self.norm2_weight = flava_layer.layernorm_after.weight
self.norm2_bias = flava_layer.layernorm_after.bias

        # Model hyperparameters
self.num_heads = flava_layer.attention.attention.num_attention_heads
self.embed_dim = int(flava_layer.attention.attention.attention_head_size * self.num_heads)

        # Track whether this is the last layer; the conversion utility sets
        # this to `True` on the model's final layer.
        self.is_last_layer = False
        # Flava blocks apply LayerNorm before attention (pre-norm), like ViT.
        self.norm_first = True

self.validate_bettertransformer()

def forward(self, hidden_states, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)
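
Because `torch._transformer_encoder_layer_fwd` replaces the whole block with a fused kernel, a quick sanity check is to compare the original and converted models on identical inputs. A sketch, mirroring the test inputs added below; the `image_embeddings` attribute follows transformers' `FlavaModelOutput`, and the tolerance is an arbitrary choice:

import requests
import torch
from PIL import Image
from optimum.bettertransformer import BetterTransformer
from transformers import AutoProcessor, FlavaModel

model_id = "hf-internal-testing/tiny-random-FlavaModel"
processor = AutoProcessor.from_pretrained(model_id)
model = FlavaModel.from_pretrained(model_id).eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(image, "How many cats are there?", return_tensors="pt")

# keep_original_model=True returns a converted copy and leaves `model` intact.
bt_model = BetterTransformer.transform(model, keep_original_model=True)

with torch.no_grad():
    ref = model(**inputs).image_embeddings
    out = bt_model(**inputs).image_embeddings

assert torch.allclose(ref, out, atol=1e-4)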
18 changes: 18 additions & 0 deletions tests/bettertransformer/test_bettertransformer_vision.py
@@ -32,6 +32,7 @@

ALL_VISION_TEXT_MODELS_TO_TEST = [
"hf-internal-testing/tiny-vilt-random-vqa",
"hf-internal-testing/tiny-random-FlavaModel",
]


@@ -66,3 +67,20 @@ def prepare_inputs_for_class(self, model_id=None):
processor = AutoProcessor.from_pretrained(model_id)
inputs = processor(image, text, return_tensors="pt")
return inputs


class BetterTransformersFlavaTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
    Testing suite for vision-and-text models; runs all the tests defined in `BetterTransformersTestMixin`.
"""
all_models_to_test = ALL_VISION_TEXT_MODELS_TO_TEST

def prepare_inputs_for_class(self, model_id=None):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "How many cats are there?"

# Model takes image and text as input
processor = AutoProcessor.from_pretrained(model_id)
inputs = processor(image, text, return_tensors="pt")
return inputs
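
With the tiny Flava checkpoint registered above, the new suite can be exercised in isolation with, for example, `pytest tests/bettertransformer/test_bettertransformer_vision.py -k Flava`.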