
Commit c8cf353

Authored Oct 9, 2023
Update attention.py (#1416)
* Update attention.py: modify the bigcode code path. This change makes the KV cache work correctly when several new tokens are passed at once.
* Consider the batch_size == 1 case.
* Update attention.py
* def kv_seq_len
1 parent: c98cb87
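The fix targets decoding with a KV cache when more than one new token is fed at once. Before this commit, batch_size == 1 with query_length > 1 always called SDPA with is_causal=True; for a non-square (query_length, kv_seq_len) problem, PyTorch's scaled_dot_product_attention aligns that causal mask to the top-left, so the new tokens cannot attend to most of the cached keys. A minimal sketch of the mismatch (illustrative shapes, not code from this commit):

import torch

# Toy shapes: 7 cached tokens in the KV cache plus 3 new tokens.
batch_size, num_heads, head_dim = 1, 4, 8
kv_seq_len, query_length = 10, 3

query = torch.rand(batch_size, num_heads, query_length, head_dim)
key = torch.rand(batch_size, num_heads, kv_seq_len, head_dim)
value = torch.rand(batch_size, num_heads, kv_seq_len, head_dim)

# is_causal=True builds a (query_length, kv_seq_len) lower-triangular mask
# aligned to the top-left, hiding most of the cache from the new tokens.
top_left = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True)

# Cached decoding actually needs a bottom-right alignment: new token i may
# attend to every cached position plus the new positions up to i.
cols = torch.arange(kv_seq_len)
rows = torch.arange(query_length)
allowed = cols[None, :] <= (kv_seq_len - query_length) + rows[:, None]  # (3, 10) bool
bottom_right = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=allowed)

print(torch.allclose(top_left, bottom_right))  # False: the two maskings disagree

The diff below addresses this by reserving is_causal=True for the training and square cases and routing the cached multi-token case through the model's boolean attention_mask instead.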

File tree: 1 file changed (+27, -22)


optimum/bettertransformer/models/attention.py (+27, -22)

@@ -699,6 +699,7 @@ def gpt_bigcode_wrapped_scaled_dot_product(
     # MHA models: (batch_size, num_heads, query_length, head_dim)
     query_shape = query.shape
     batch_size = query_shape[0]
+    kv_seq_len = key.shape[-2]
 
     if self.multi_query:
         query_length = query_shape[1]
@@ -725,30 +726,34 @@ def gpt_bigcode_wrapped_scaled_dot_product(
         key = key.expand(-1, self.num_heads, -1, -1)
         value = value.expand(-1, self.num_heads, -1, -1)
 
-    if batch_size == 1 or self.training:
-        if query_length > 1:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
-            )
-        else:
-            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
-            )
+    # We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
+    if self.training:
+        is_causal = True
+        attn_mask = None
+    elif batch_size == 1 and query_length == 1:
+        is_causal = False
+        attn_mask = None
+    elif batch_size == 1 and kv_seq_len == query_length:
+        is_causal = True
+        attn_mask = None
+    elif attention_mask is not None:
+        mask_value = self._get_mask_value(query.device, query.dtype)
+
+        # gpt_bigcode has the bad taste to use a causal mask a
+        # [batch_size, target_length, 1, source_length] which is different from
+        # **all** other architectures and not compatible with SDPA.
+        # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
+        # but it is probably not worth it.
+        attention_mask = attention_mask.transpose(1, 2)
+        attn_mask = torch.where(attention_mask, 0.0, mask_value)
+        is_causal = False
     else:
-        if attention_mask is not None:
-            mask_value = self._get_mask_value(query.device, query.dtype)
+        attn_mask = None
+        is_causal = True
 
-            # gpt_bigcode has the bad taste to use a causal mask a
-            # [batch_size, target_length, 1, source_length] which is different from
-            # **all** other architectures and not compatible with SDPA.
-            # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
-            # but it is probably not worth it.
-            attention_mask = attention_mask.transpose(1, 2)
-            attention_mask = torch.where(attention_mask, 0.0, mask_value)
-
-        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
-        )
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+    )
 
     if self.multi_query:
         # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
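For reference, here is the same case analysis restated as a standalone sketch. pick_sdpa_args is a hypothetical helper written for this page, not part of optimum; mask_value stands in for what self._get_mask_value returns in the real code:

import torch

def pick_sdpa_args(training, batch_size, query_length, kv_seq_len, attention_mask, mask_value):
    # Mirror of the case analysis introduced by the diff above.
    if training:
        # No cache during training: a plain causal mask is correct and keeps
        # the call eligible for the Flash Attention backend.
        return None, True
    if batch_size == 1 and query_length == 1:
        # Single-token decoding: the new token attends to the whole cache.
        return None, False
    if batch_size == 1 and kv_seq_len == query_length:
        # Prefill with no cache: square causal mask, still Flash-Attention friendly.
        return None, True
    if attention_mask is not None:
        # General case, including several new tokens on top of a KV cache.
        # gpt_bigcode stores its boolean mask as [batch_size, target_length, 1, source_length],
        # so transpose it for SDPA and convert it to an additive float mask.
        attention_mask = attention_mask.transpose(1, 2)
        return torch.where(attention_mask, 0.0, mask_value), False
    return None, True

# Example: batch of 1, 3 new tokens appended to a cache of 7 (kv_seq_len = 10).
bool_mask = torch.ones(1, 3, 1, 10, dtype=torch.bool)  # all-True mask, shapes only
attn_mask, is_causal = pick_sdpa_args(
    training=False, batch_size=1, query_length=3, kv_seq_len=10,
    attention_mask=bool_mask, mask_value=torch.tensor(torch.finfo(torch.float32).min),
)
# attn_mask has shape (1, 1, 3, 10) and is_causal is False, so the subsequent
# scaled_dot_product_attention call relies entirely on the explicit mask.

Only the training, single-token, and square branches return attn_mask=None, which is what keeps the dispatch to Flash Attention available there; everything else, including several new tokens on top of a cache, goes through the explicit float mask.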
