Skip to content

Commit 4f79e05

Browse files
authored
Add transformers 4.49 support (huggingface#1172)
* transformers 4.49 * fix qwen2vl patcher * disable tests for models incompatibles with 4.49 * fix * fix * skip tests * disable test * disable test * udpate expected tests op quantized * fix * update quant op tests * add back test * fix pattern * style * disable * add minicpmv back
1 parent 0652389 commit 4f79e05

File tree

4 files changed

+63
-27
lines changed

4 files changed

+63
-27
lines changed

optimum/exporters/openvino/model_patcher.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -3935,14 +3935,28 @@ def __enter__(self):
39353935
# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
39363936
# added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len)
39373937
def sdpa_attn_forward(
3938-
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None
3938+
self,
3939+
hidden_states: torch.Tensor,
3940+
attention_mask: torch.Tensor,
3941+
rotary_pos_emb: torch.Tensor = None,
3942+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
39393943
) -> torch.Tensor:
39403944
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision
39413945

39423946
seq_length = hidden_states.shape[0]
39433947
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
3944-
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
3945-
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
3948+
3949+
if is_transformers_version(">=", "4.49"):
3950+
if position_embeddings is None:
3951+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
3952+
cos = emb.cos().float()
3953+
sin = emb.sin().float()
3954+
else:
3955+
cos, sin = position_embeddings
3956+
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
3957+
else:
3958+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
3959+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
39463960

39473961
q = q.transpose(0, 1)
39483962
k = k.transpose(0, 1)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
INSTALL_REQUIRE = [
3030
"torch>=1.11",
3131
"optimum~=1.24",
32-
"transformers>=4.36,<4.49",
32+
"transformers>=4.36,<4.50",
3333
"datasets>=1.4.0",
3434
"sentencepiece",
3535
"setuptools",

tests/openvino/test_modeling.py

+18
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,13 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
10391039
@parameterized.expand(SUPPORTED_ARCHITECTURES)
10401040
def test_compare_to_transformers(self, model_arch):
10411041
model_id = MODEL_NAMES[model_arch]
1042+
1043+
# TODO: add back once dtype fixed everywhere
1044+
# https://huggingface.co/katuni4ka/tiny-random-chatglm2/blob/main/modeling_chatglm.py#L720
1045+
# https://huggingface.co/katuni4ka/tiny-random-chatglm2/blob/main/modeling_chatglm.py#L759
1046+
if model_arch in {"chatglm", "glm4"} and is_transformers_version(">=", "4.49"):
1047+
self.skipTest("Incompatible modeling code")
1048+
10421049
not_stateful = []
10431050
if is_openvino_version("<", "2024.0"):
10441051
not_stateful.append("mixtral")
@@ -1117,6 +1124,11 @@ def test_compare_to_transformers(self, model_arch):
11171124
)
11181125

11191126
ov_outputs = ov_model.generate(**tokens, generation_config=gen_config)
1127+
1128+
# TODO: add back once https://huggingface.co/katuni4ka/tiny-random-minicpm3/discussions/1 merged (for all models) as current mdoeling incompatible with transformers >= v4.49
1129+
if model_arch in {"minicpm", "minicpm3", "arctic", "deepseek"} and is_transformers_version(">=", "4.49"):
1130+
self.skipTest("Incompatible modeling code")
1131+
11201132
additional_inputs = {}
11211133
# gemma2 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache,
11221134
# align cache representation in torch model
@@ -2119,6 +2131,7 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
21192131
SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
21202132
if is_transformers_version(">=", "4.45.0"):
21212133
SUPPORTED_ARCHITECTURES += ["minicpmv", "internvl2", "phi3_v", "qwen2_vl"]
2134+
21222135
if is_transformers_version(">=", "4.46.0"):
21232136
SUPPORTED_ARCHITECTURES += ["maira2"]
21242137

@@ -2220,6 +2233,11 @@ def test_compare_to_transformers(self, model_arch):
22202233
set_seed(SEED)
22212234
ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
22222235
set_seed(SEED)
2236+
2237+
# TODO: add back once https://huggingface.co/katuni4ka/tiny-random-minicpm3/discussions/1 merged for all models as current mdoeling incompatible with transformers >= v4.49
2238+
if model_arch in {"phi3_v", "nanollava"} and is_transformers_version(">=", "4.49"):
2239+
self.skipTest("Incompatible modeling code")
2240+
22232241
with torch.no_grad():
22242242
transformers_outputs = transformers_model.generate(**transformers_inputs, generation_config=gen_config)
22252243

tests/openvino/test_quantization.py

+27-23
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@
9191
"text-classification": ("glue", "sst2", "sentence"),
9292
}
9393

94+
pattern_prefix = "^__module.model.model" if is_transformers_version(">=", "4.49") else "^__module.model"
95+
9496

9597
class OVQuantizerTest(unittest.TestCase):
9698
SUPPORTED_ARCHITECTURES_TORCH_MODEL = (
@@ -158,12 +160,12 @@ class OVQuantizerTest(unittest.TestCase):
158160
dtype="nf4",
159161
group_size=16,
160162
ratio=0.5,
161-
ignored_scope={"patterns": ["^__module.model.layers.0.self_attn"]},
163+
ignored_scope={"patterns": [f"{pattern_prefix}.layers.0.self_attn"]},
162164
),
163165
full_quantization_config=OVQuantizationConfig(
164-
dtype="f8e4m3", ignored_scope={"patterns": ["^__module.model.layers.0.mlp"]}
166+
dtype="f8e4m3", ignored_scope={"patterns": [f"{pattern_prefix}.layers.0.mlp"]}
165167
),
166-
ignored_scope={"patterns": ["^__module.model.layers.1.self_attn"]},
168+
ignored_scope={"patterns": [f"{pattern_prefix}.layers.1.self_attn"]},
167169
dataset="wikitext2",
168170
num_samples=1,
169171
),
@@ -183,12 +185,12 @@ class OVQuantizerTest(unittest.TestCase):
183185
dtype="nf4",
184186
group_size=16,
185187
ratio=0.5,
186-
ignored_scope={"patterns": ["^__module.model.layers.0.self_attn"]},
188+
ignored_scope={"patterns": [f"{pattern_prefix}.layers.0.self_attn"]},
187189
),
188190
full_quantization_config=OVQuantizationConfig(
189-
dtype="f8e5m2", ignored_scope={"patterns": ["^__module.model.layers.0.mlp"]}
191+
dtype="f8e5m2", ignored_scope={"patterns": [f"{pattern_prefix}.layers.0.mlp"]}
190192
),
191-
ignored_scope={"patterns": ["^__module.model.layers.1.self_attn"]},
193+
ignored_scope={"patterns": [f"{pattern_prefix}.layers.1.self_attn"]},
192194
dataset="wikitext2",
193195
num_samples=1,
194196
),
@@ -435,7 +437,7 @@ class OVWeightCompressionTest(unittest.TestCase):
435437
sensitivity_metric="mean_activation_magnitude",
436438
dataset="c4",
437439
),
438-
[{"int8": 14, "int4": 25}],
440+
[{"int8": 18, "int4": 23}] if is_transformers_version(">=", "4.49") else [{"int8": 14, "int4": 25}],
439441
),
440442
(
441443
OVModelForCausalLM,
@@ -449,7 +451,7 @@ class OVWeightCompressionTest(unittest.TestCase):
449451
sensitivity_metric="mean_activation_magnitude",
450452
dataset=["one two, " * i for i in range(10)],
451453
),
452-
[{"int8": 16, "int4": 24}],
454+
[{"int8": 18, "int4": 23}] if is_transformers_version(">=", "4.49") else [{"int8": 16, "int4": 24}],
453455
),
454456
(
455457
OVModelForCausalLM,
@@ -612,21 +614,23 @@ class OVWeightCompressionTest(unittest.TestCase):
612614
),
613615
[{"int8": 8, "int4": 22}, {"int8": 1}, {"int8": 11}],
614616
),
615-
(
616-
OVModelForVisualCausalLM,
617-
"phi3_v",
618-
True,
619-
dict(
620-
bits=4,
621-
group_size=16,
622-
dataset="contextual",
623-
ratio=0.8,
624-
sensitivity_metric="mean_activation_magnitude",
625-
num_samples=1,
626-
trust_remote_code=True,
627-
),
628-
[{"int8": 4, "int4": 14}, {"int8": 1}, {"int8": 7}, {"int8": 2}],
629-
),
617+
# TODO: add back once https://huggingface.co/katuni4ka/tiny-random-phi3-vision/blob/main/processing_phi3_v.py#L313 modified to add chat_template
618+
# currently incompatible with transformers >= v4.49
619+
# (
620+
# OVModelForVisualCausalLM,
621+
# "phi3_v",
622+
# True,
623+
# dict(
624+
# bits=4,
625+
# group_size=16,
626+
# dataset="contextual",
627+
# ratio=0.8,
628+
# sensitivity_metric="mean_activation_magnitude",
629+
# num_samples=1,
630+
# trust_remote_code=True,
631+
# ),
632+
# [{"int8": 4, "int4": 14}, {"int8": 1}, {"int8": 7}, {"int8": 2}],
633+
# ),
630634
(
631635
OVModelForVisualCausalLM,
632636
"qwen2_vl",

0 commit comments

Comments
 (0)