Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding BetterTransformer support for ProphetNet #648

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions optimum/bettertransformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
ViTLayerBetterTransformer,
Wav2Vec2EncoderLayerBetterTransformer,
WhisperEncoderLayerBetterTransformer,
ProphetNetEncoderLayerBetterTransformer
)


class BetterTransformerManager:
MODEL_MAPPING = {
"albert": ("AlbertLayer", AlbertLayerBetterTransformer),
Expand All @@ -47,6 +47,7 @@ class BetterTransformerManager:
"m2m_100": ("M2M100EncoderLayer", MBartEncoderLayerBetterTransformer),
"markuplm": ("MarkupLMLayer", BertLayerBetterTransformer),
"mbart": ("MBartEncoderLayer", MBartEncoderLayerBetterTransformer),
"prophetnet": ("ProphetNetEncoderLayer", ProphetNetEncoderLayerBetterTransformer),
"rembert": ("RemBertLayer", BertLayerBetterTransformer),
"roberta": ("RobertaLayer", BertLayerBetterTransformer),
"splinter": ("SplinterLayer", BertLayerBetterTransformer),
Expand All @@ -58,7 +59,7 @@ class BetterTransformerManager:
"wav2vec2": ("Wav2Vec2EncoderLayer", Wav2Vec2EncoderLayerBetterTransformer),
"whisper": ("WhisperEncoderLayer", WhisperEncoderLayerBetterTransformer),
"xlm-roberta": ("XLMRobertaLayer", BertLayerBetterTransformer),
"yolos": ("YolosLayer", ViTLayerBetterTransformer),
"yolos": ("YolosLayer", ViTLayerBetterTransformer)
}

EXCLUDE_FROM_TRANSFORM = {
Expand Down
117 changes: 115 additions & 2 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,6 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
return (hidden_states, attention_mask)


class CLIPLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, layer, config):
r"""
Expand Down Expand Up @@ -1155,7 +1154,7 @@ def forward(self, hidden_states, attention_mask, *_, **__):
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

# we expect attention_mask to be None in the vision model
if attention_mask is not None:
raise ValueError(
Expand Down Expand Up @@ -1192,3 +1191,117 @@ def _get_activation_function(self, config: "PretrainedConfig"):
return config.vision_config.hidden_act
else:
return config.hidden_act

class ProphetNetEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, prnt_layer, config):
r"""
A conversion of the ProphetNet Encoder layer to its `BetterTransformer` implementation.

Args:
prnt_layer (`torch.nn.Module`):
The original ProphetNetEncoderLayer where the weights needs to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
prnt_layer.self_attn.query_proj.weight,
prnt_layer.self_attn.key_proj.weight,
prnt_layer.self_attn.value_proj.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
prnt_layer.self_attn.query_proj.bias,
prnt_layer.self_attn.key_proj.bias,
prnt_layer.self_attn.value_proj.bias,
]
)
)

# Out proj layer
self.out_proj_weight = prnt_layer.self_attn.out_proj.weight
self.out_proj_bias = prnt_layer.self_attn.out_proj.bias

# Linear layer 1
self.linear1_weight = prnt_layer.feed_forward.intermediate.weight
self.linear1_bias = prnt_layer.feed_forward.intermediate.bias

# Linear layer 2
self.linear2_weight = prnt_layer.feed_forward.output.weight
self.linear2_bias = prnt_layer.feed_forward.output.bias

# Layer norm 1
self.norm1_eps = prnt_layer.self_attn_layer_norm.eps
self.norm1_weight = prnt_layer.self_attn_layer_norm.weight
self.norm1_bias = prnt_layer.self_attn_layer_norm.bias

# Layer norm 2
self.norm2_eps = prnt_layer.feed_forward_layer_norm.eps
self.norm2_weight = prnt_layer.feed_forward_layer_norm.weight
self.norm2_bias = prnt_layer.feed_forward_layer_norm.bias

# Model hyper parameters
self.num_heads = prnt_layer.self_attn.num_attn_heads
self.embed_dim = prnt_layer.self_attn.head_dim * self.num_heads

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

if hidden_states.is_nested:
attention_mask = None

if attention_mask is not None:
# attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
# 0->false->keep this token -inf->true->mask this token
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
seqlen = attention_mask.shape[1]
lengths = torch.sum(~attention_mask, 1)

attention_mask = attention_mask.unsqueeze(1).unique(dim=0).squeeze()

if hidden_states.shape[0] != attention_mask.shape[0]:
hidden_states = hidden_states.transpose(1, 0)
Comment on lines +1276 to +1277
Copy link
Contributor

@fxmarty fxmarty Jan 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the error you get is due to those you could try to remove, not sure.

Copy link
Contributor Author

@adit299 adit299 Jan 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried that.. no luck, those tests are still failing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By using print statements, I can see that hidden_states is of shape: torch.Size([2, 4, 16]) and attention_mask is of shape: torch.Size([8, 4]). So there appears to be a discrepancy in the batch_size value?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @adit299 I'll look asap, probably on Friday. In the meanwhile you can compare with the shapes for e.g. Bert to see how it differs.


if not all([l == seqlen for l in lengths]):
hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states, attention_mask)

1 change: 1 addition & 0 deletions tests/bettertransformer/test_bettertransformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"hf-internal-testing/tiny-random-FSMTModel",
"hf-internal-testing/tiny-random-mbart",
"hf-internal-testing/tiny-random-nllb",
"hf-internal-testing/tiny-random-ProphetNetModel"
]


Expand Down