Commit d1bcdf7

Transformers 4.48 (#2158)

* test
* testing tensor cache x)
* fix logger
* condition cache class usage
* update opset for beit and data2vec vision and skip flattened/fused pkv (e.g. gpt bigcode)
* style
* fix args patcher
* fix modernbert testing
* adapt to new whisper returned generation length
* fix is_causal in transformers
* fix modernbert failures
* style
* traceable cache
* use pkv index
* add version guard and clean up other model patcher version guards
* patch sdpa attention in optimum for now
* remove modernbert condition
* style
* fix MistralModelPatcher
* correctly patch gpt2 in vision encoder decoder
* patch sdpa attention forward everywhere (sketched below)
* fix gpt2 cross attention in seq2seq as well
* moved traceable cache to a file for simplicity of model patcher
* Apply suggestions from code review
* style
* fix
1 parent 50531a4 commit d1bcdf7
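
Context for the "patch sdpa attention" items above: during ONNX export, `is_causal` passed to SDPA must be a plain Python bool rather than a tensor for tracing to succeed. Below is a minimal sketch of that patching pattern, not optimum's actual implementation; the module path `transformers.integrations.sdpa_attention` and the causality rule used here are assumptions to verify against your installed transformers version.

import contextlib

import transformers.integrations.sdpa_attention as sdpa


@contextlib.contextmanager
def patched_sdpa_forward():
    # Swap the SDPA attention forward for a tracing-friendly wrapper, then restore it.
    original = sdpa.sdpa_attention_forward

    def traceable_forward(module, query, key, value, attention_mask=None, **kwargs):
        # Hypothetical rule: causal only when no explicit mask is given and there is
        # more than one query token; crucially a Python bool, never a traced tensor.
        kwargs["is_causal"] = attention_mask is None and query.shape[2] > 1
        return original(module, query, key, value, attention_mask, **kwargs)

    sdpa.sdpa_attention_forward = traceable_forward
    try:
        yield
    finally:
        sdpa.sdpa_attention_forward = original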

File tree

7 files changed: +272 -115 lines changed
+95
@@ -0,0 +1,95 @@
import logging
from typing import Any, Dict, Optional, Tuple

import torch


# Note: `warning_once` used below is not part of the stdlib `logging.Logger`;
# transformers patches it onto the class at import time, which this module relies on.
logger = logging.getLogger(__name__)


# Simply removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873
class TraceableCache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    # Deprecated in favor of `get_max_cache_shape` because we want to be specific about what we mean by "max_length".
    # Previously, some cache objects didn't have a "max_length" (SlidingWindowCache or SinkCache) because the cache
    # technically handles an infinite amount of tokens. What the codebase really needs to check is the max capacity of
    # certain cache instances, so the naming is changed to be more explicit.
    def get_max_length(self) -> Optional[int]:
        logger.warning_once(
            "`get_max_length()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
            "Calling `get_max_length()` will raise an error from v4.48"
        )
        return self.get_max_cache_shape()

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length (i.e. max capacity) of the cache object."""
        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable.
        # Cache with size limit -> if the length of the cache plus the length of the new inputs is larger than the
        # maximum cache length, we will need to evict part of the cache (and thus not all cache is usable).
        max_length = self.get_max_cache_shape()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] != []:
                device = self.key_cache[layer_idx].device
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            if self.value_cache[layer_idx] != []:
                device = self.value_cache[layer_idx].device
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
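
For illustration, here is a hypothetical minimal subclass (not part of this commit) showing how the contract above can be satisfied by a dynamically growing cache:

class SimpleDynamicCache(TraceableCache):
    """Hypothetical example: a cache that grows by concatenation along the sequence axis."""

    def __init__(self):
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        if len(self.key_cache) <= layer_idx:
            # First time this layer is seen: store the states as-is.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # States are (batch, num_heads, seq_len, head_dim); append along seq_len.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        return None  # Dynamic cache: no fixed maximum capacity.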

optimum/exporters/onnx/model_configs.py

+3 -4
@@ -847,7 +847,7 @@ class DeiTOnnxConfig(ViTOnnxConfig):
 
 
 class BeitOnnxConfig(ViTOnnxConfig):
-    DEFAULT_ONNX_OPSET = 11
+    DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
 
 
 class ConvNextOnnxConfig(ViTOnnxConfig):

@@ -1577,13 +1577,12 @@ class Data2VecTextOnnxConfig(DistilBertOnnxConfig):
 
 
 class Data2VecVisionOnnxConfig(ViTOnnxConfig):
-    DEFAULT_ONNX_OPSET = 11
+    DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
 
 
 class Data2VecAudioOnnxConfig(AudioOnnxConfig):
-    NORMALIZED_CONFIG_CLASS = NormalizedConfig
-    ATOL_FOR_VALIDATION = 1e-4
     DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
 
 
 class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
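
The opset bumps above are needed because torch.onnx supports exporting aten::scaled_dot_product_attention only from opset 14 onwards, and these models now hit that operator by default with torch>=2.1.1. As a usage sketch (the checkpoint name is an arbitrary example; verify the main_export argument names against your installed optimum):

from optimum.exporters.onnx import main_export

# Export with an explicit opset >= 14 so F.scaled_dot_product_attention can be exported.
main_export(
    "microsoft/beit-base-patch16-224",  # arbitrary example checkpoint
    output="beit_onnx",
    opset=14,
)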
