Commit 920341a

fix compatibility with latest transformers release
1 parent fd4e281 commit 920341a

1 file changed: +9 -4 lines changed

optimum/exporters/openvino/model_patcher.py (+9 -4)

@@ -291,11 +291,9 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, **kwargs):
+def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter

-    # for compatibility with https://github.com/huggingface/transformers/pull/30047
-    current_length = kwargs.get("current_length", cache_position[-1])
     dtype, device = input_tensor.dtype, input_tensor.device

     # using minimum from dtype with larger bandwidth (float32) may lead to overflow
@@ -305,7 +303,14 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
-        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+        if past_seen_tokens is not None:
+            current_length = past_seen_tokens + sequence_length + 1
+        # TODO : remove after support of transformers >= v4.40.0
+        else:
+            current_length = cache_position[-1] + 1
+
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length
+

     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
     if sequence_length != 1:
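
The change replaces the **kwargs shim with an explicit past_seen_tokens argument: the latest transformers releases (the v4.40 line referenced in the TODO) pass past_seen_tokens into this hook, while the v4.39.3 implementations linked above only supply cache_position, so the else branch keeps older releases working. Below is a minimal sketch, not part of the commit; the helper name and the example values are illustrative assumptions, showing only how the dynamic-cache target length is resolved under the two calling conventions.

import torch


def dynamic_cache_target_length(attention_mask, cache_position, sequence_length, past_seen_tokens=None):
    # Mirrors the dynamic-cache branch added in this commit (names here are assumptions).
    if past_seen_tokens is not None:
        # newer transformers releases pass past_seen_tokens into the mask hook
        current_length = past_seen_tokens + sequence_length + 1
    else:
        # transformers v4.39.x only provides cache_position
        current_length = int(cache_position[-1]) + 1
    # when a real attention-mask tensor is given, its last dimension takes precedence
    return attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length


# Illustrative call: 4 cached tokens plus 1 new token, mask covering all 5 positions.
mask = torch.ones(1, 5, dtype=torch.long)
print(dynamic_cache_target_length(mask, torch.tensor([4]), sequence_length=1, past_seen_tokens=4))  # 5
print(dynamic_cache_target_length(None, torch.tensor([4]), sequence_length=1))                      # 5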
