Transformers 4.48 #1136

Merged
12 commits merged on Feb 4, 2025
32 changes: 18 additions & 14 deletions optimum/exporters/openvino/model_patcher.py
@@ -718,14 +718,15 @@ def _mistral_update_causal_mask(
class MistralModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.42.0"):
if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
# apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)

else:
for layer in self._model.model.layers:
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
if hasattr(layer.self_attn, "rotary_emb"):
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
@@ -734,7 +735,7 @@ def __exit__(self, exc_type, exc_value, traceback):
self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask

for layer in self._model.model.layers:
if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
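
Note: the patcher changes in this file all follow the same compatibility pattern: the legacy monkey-patches are gated to transformers >= 4.42 and < 4.48, and rotary_emb is only touched when the attention module still exposes it (newer releases move rotary embeddings off the per-layer attention). A minimal sketch of that pattern, assuming the helper names already defined in model_patcher.py and the usual optimum-intel import path:

# --- illustrative sketch, not part of the diff ---
import types

from optimum.intel.utils.import_utils import is_transformers_version  # assumed import path


def apply_compat_patches(model, patched_update_causal_mask, reinit_cos_sin_fp32):
    if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
        # Pre-4.48 releases still need the fixed `_update_causal_mask`.
        model.model._orig_update_causal_mask = model.model._update_causal_mask
        model.model._update_causal_mask = types.MethodType(patched_update_causal_mask, model.model)
    else:
        for layer in model.model.layers:
            # transformers 4.48 drops `rotary_emb` from the attention modules,
            # so guard the attribute before re-initializing cos/sin caches.
            if hasattr(layer.self_attn, "rotary_emb"):
                reinit_cos_sin_fp32(layer.self_attn.rotary_emb)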


@@ -1580,19 +1581,19 @@ def __enter__(self):
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

if is_transformers_version(">=", "4.42.0"):
if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
self._model.model._orig_forward = self._model.model.forward
self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)

# https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
# init inv_freq for torchscript tracing
for layer in self._model.model.layers:
if is_torch_version(">=", "2.1.0"):
if is_torch_version(">=", "2.1.0") and is_transformers_version("<", "4.48.0"):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

if layer.self_attn.rotary_emb.inv_freq is None:
if hasattr(layer.self_attn, "rotary_emb") and layer.self_attn.rotary_emb.inv_freq is None:
rotary_emb = layer.self_attn.rotary_emb
layer.self_attn.rotary_emb.inv_freq = 1.0 / (
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
@@ -2493,7 +2494,9 @@ class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
patch_update_causal_mask(self._model, "4.42.0")
if hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"):
if hasattr(self._model.model.layers[0].self_attn, "rotary_emb") and hasattr(
self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"
):
for layer in self._model.model.layers:
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)

@@ -3045,15 +3048,16 @@ def patched_forward(self, fn):
def __enter__(self):
if is_torch_version(">=", "2.1.0"):
if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
if is_transformers_version("<", "4.48"):
from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES

sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"
sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"

for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)

if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
6 changes: 5 additions & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -56,7 +56,11 @@


if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
try:
from transformers.generation.streamers import BaseStreamer
except Exception:
from typing import Generator as BaseStreamer

from transformers.modeling_utils import PreTrainedModel
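
Aside: the guarded BaseStreamer import above only matters under static type checking. A self-contained sketch of the pattern (the fallback alias mirrors the one chosen in this diff):

# --- illustrative sketch, not part of the diff ---
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    try:
        from transformers.generation.streamers import BaseStreamer
    except Exception:
        # Stand-in alias so annotations such as Optional["BaseStreamer"]
        # still resolve when the class is unavailable.
        from typing import Generator as BaseStreamer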


13 changes: 13 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -1020,6 +1020,18 @@ def preprocess_inputs(
prompt = "<image>\n" + text
else:
prompt = text

if getattr(processor, "patch_size", None) is None:
if (
getattr(config, "vision_config", None) is not None
and getattr(config.vision_config, "patch_size", None) is not None
):
processor.patch_size = config.vision_config.patch_size
else:
raise ValueError(
"Processor does not have `patch_size` attribute. Please fix the processor or provide `patch_size` in the config."
)

inputs = processor(images=image, text=prompt, return_tensors="pt")
return inputs
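
For context, the patch_size fallback above matters when a saved processor ships without `patch_size` while the model config carries it. A hedged usage sketch (the checkpoint id is only an example, not tied to this PR):

# --- illustrative sketch, not part of the diff ---
from transformers import AutoConfig, AutoProcessor

model_id = "llava-hf/llava-1.5-7b-hf"  # example llava checkpoint
processor = AutoProcessor.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)

# Mirror of the fallback added above: copy patch_size from the vision config
# so image-token expansion keeps working with older processor files.
if getattr(processor, "patch_size", None) is None:
    if getattr(config, "vision_config", None) is not None and getattr(config.vision_config, "patch_size", None) is not None:
        processor.patch_size = config.vision_config.patch_size
    else:
        raise ValueError("Processor does not have `patch_size` and the config does not provide one.")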

@@ -1915,6 +1927,7 @@ def preprocess_inputs(
input_ids = tokenizer(text, return_tensors="pt").input_ids
attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
result = {"input_ids": input_ids, "attention_mask": attention_mask}

if image is not None:
result["images"] = processor(images=[image], return_tensors="pt")["pixel_values"]
return result
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@
INSTALL_REQUIRE = [
"torch>=1.11",
"optimum~=1.24",
"transformers>=4.36,<4.48",
"transformers>=4.36,<4.49",
"datasets>=1.4.0",
"sentencepiece",
"setuptools",
17 changes: 11 additions & 6 deletions tests/openvino/test_modeling.py
@@ -271,6 +271,7 @@ def test_load_from_hub_and_save_visual_language_model(self):
else:
self.assertEqual(component.request.get_property("PERFORMANCE_HINT"), "LATENCY")

processor.patch_size = loaded_model.config.vision_config.patch_size
inputs = processor(images=image, text=prompt, return_tensors="pt")
set_seed(SEED)
loaded_model_outputs = loaded_model(**inputs)
@@ -2170,6 +2171,7 @@ def test_compare_to_transformers(self, model_arch):
for component_name, component in ov_model.components.items():
self.assertIsInstance(component, MODEL_PARTS_CLS_MAPPING[component_name])
self.assertIsInstance(ov_model.config, PretrainedConfig)

inputs = ov_model.preprocess_inputs(**preprocessors, text=prompt, image=self.IMAGE.resize((600, 600)))
transformers_inputs = copy.deepcopy(inputs)
test_device = "AUTO"
@@ -2235,6 +2237,7 @@ def test_llava_with_new_preprocessing(self, model_arch):
patch_size=config.vision_config.patch_size,
vision_feature_select_strategy=config.vision_feature_select_strategy,
trust_remote_code=model_arch in self.REMOTE_CODE_MODELS,
num_additional_image_tokens=1,
)
transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(model_id)
ov_model = OVModelForVisualCausalLM.from_pretrained(
@@ -2244,8 +2247,9 @@
self.assertTrue(processor.patch_size is not None)
self.assertTrue(processor.vision_feature_select_strategy is not None)
inputs = processor(images=self.IMAGE, text=prompt, return_tensors="pt")
self.assertTrue(
(inputs.input_ids == ov_model.config.image_token_index).sum(1).max() >= ov_model.config.image_seq_length
self.assertGreaterEqual(
(inputs.input_ids == ov_model.config.image_token_index).sum().max().item(),
ov_model.config.image_seq_length,
)
set_seed(SEED)
with torch.no_grad():
@@ -2308,17 +2312,17 @@ def test_generate_utils(self, model_arch):

def get_preprocessors(self, model_arch):
model_id = MODEL_NAMES[model_arch]
config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)

if model_arch == "nanollava":
config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
processor = AutoProcessor.from_pretrained(
config.mm_vision_tower, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
tokenizer = AutoTokenizer.from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
preprocessors = {"processor": processor, "tokenizer": tokenizer}
preprocessors = {"processor": processor, "tokenizer": tokenizer, "config": config}
elif model_arch == "internvl2":
config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
tokenizer = AutoTokenizer.from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
@@ -2327,7 +2331,8 @@ def get_preprocessors(self, model_arch):
processor = AutoProcessor.from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
preprocessors = {"processor": processor, "tokenizer": None}
preprocessors = {"processor": processor, "tokenizer": None, "config": config}

return preprocessors

@parameterized.expand(SUPPORTED_ARCHITECTURES)