
Commit 7a4464d

Fix compatibility with transformers v4.40.0 (#682)
* bump transformers version
* trigger test
* fix compatibility with latest transformers release
* format
* fix
* tmp add test
* remove tmp tests
1 parent a0dc06c commit 7a4464d

2 files changed: +17 −5 lines


optimum/exporters/openvino/model_patcher.py (+16 −4)
@@ -291,11 +291,17 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs):
+def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
-    # for compatibility with https://github.com/huggingface/transformers/pull/30047
-    current_length = kwargs.get("current_length", cache_position[-1])
+    if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+        # in order to dispatch on Flash Attention 2.
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+        ):
+            return None
+
     dtype, device = input_tensor.dtype, input_tensor.device
 
     # using minimum from dtype with larger bandwith (floa32) may lead to overflow
@@ -305,7 +311,13 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+        if past_seen_tokens is not None:
+            current_length = past_seen_tokens + sequence_length + 1
+        # TODO : remove after support of transformers >= v4.40.0
+        else:
+            current_length = cache_position[-1] + 1
+
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length
 
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
     if sequence_length != 1:
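
The patch swaps the `**kwargs` lookup for an explicit `past_seen_tokens` keyword: with transformers v4.40 the mask builder receives the number of previously seen tokens and can return `None` early when SDPA's `is_causal` path applies, while older releases still fall back to `cache_position`. Below is a minimal sketch (not part of the commit) of how a caller could pick the matching call convention; the `build_causal_mask` wrapper and its arguments are hypothetical.

```python
# Illustrative sketch only: dispatch on the installed transformers release and
# call the patched helper from the diff above with the matching arguments.
from packaging import version

import transformers
# Import path mirrors the file touched above, where the helper is module-level.
from optimum.exporters.openvino.model_patcher import _llama_gemma_update_causal_mask


def build_causal_mask(model, attention_mask, input_tensor, cache_position, past_key_values=None):
    if version.parse(transformers.__version__) >= version.parse("4.40.0"):
        # transformers >= 4.40 tracks previously seen tokens on the cache object;
        # the patched helper uses it for the SDPA short-circuit and `current_length`.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        return _llama_gemma_update_causal_mask(
            model, attention_mask, input_tensor, cache_position, past_seen_tokens=past_seen_tokens
        )
    # Older releases: the helper derives the length from `cache_position` alone.
    return _llama_gemma_update_causal_mask(model, attention_mask, input_tensor, cache_position)
```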

setup.py (+1 −1)
@@ -28,7 +28,7 @@
 
 INSTALL_REQUIRE = [
     "torch>=1.11",
-    "transformers>=4.36.0,<4.40.0",
+    "transformers>=4.36.0,<4.41.0",
     "optimum~=1.19",
     "datasets>=1.4.0",
     "sentencepiece",
