Commit 27f392e

add tests and docs

1 parent 7947b6a commit 27f392e
File tree: 6 files changed, +145 −9 lines changed

docs/source/openvino/models.mdx (+1)

@@ -62,6 +62,7 @@ Here is the list of the supported architectures :
 - GPT-NeoX-Japanese
 - Gemma
 - Gemma2
+- Gemma3
 - Granite
 - GraniteMoE
 - Hubert
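
The new entry can be exercised end to end once this commit lands. A minimal sketch, not part of the commit: the checkpoint name is only illustrative, and OVModelForCausalLM is assumed to pick up the new Gemma3 text export config.

    # Hypothetical quick check of the text-only export path added here.
    from optimum.intel import OVModelForCausalLM

    # export=True converts the checkpoint to OpenVINO IR on the fly
    model = OVModelForCausalLM.from_pretrained("google/gemma-3-1b-it", export=True)
    model.save_pretrained("gemma3_text_ov")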

optimum/exporters/openvino/model_configs.py (+5)

@@ -1422,6 +1422,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
                 inputs_embed_shape
             )
             dummy_inputs["inputs_embeds"] = inputs_embeds
+        if "token_type_ids" in self.inputs:
+            dummy_inputs["token_type_ids"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[
+                0
+            ].random_int_tensor(input_ids.shape, min_value=0, max_value=2)
         return dummy_inputs


@@ -3058,5 +3062,6 @@ def with_behavior(
                 self.int_dtype,
                 self.float_dtype,
                 model_patcher=Gemma3LMModelPatcher,
+                inputs_update={"token_type_ids": {0: "batch_size", 1: "past_sequence_length + 1"}},
             )
         return super().with_behavior(behavior)
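
The dummy token_type_ids input mirrors Gemma3's convention of distinguishing text positions from image positions. A rough stand-in for what the generator produces, assuming random_int_tensor samples in [min_value, max_value) like torch.randint (that behaviour is an assumption here, not something this diff states):

    # Rough stand-in for the dummy token_type_ids generated above.
    import torch

    batch_size, seq_len = 2, 8  # illustrative dummy-input shape
    token_type_ids = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64)
    # Gemma3 convention: 0 marks text positions, 1 marks image positions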

optimum/exporters/openvino/model_patcher.py (+77 −5)

@@ -4414,6 +4414,7 @@ def __init__(
         model_kwargs: Dict[str, Any],
     ):
         model.__orig_forward = model.forward
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
         model.forward = model.get_image_features
         super().__init__(config, model, model_kwargs)

@@ -4422,23 +4423,94 @@ def __exit__(self, exc_type, exc_value, traceback):
         self._model.forward = self._model.__orig_forward


-class Gemma3LMModelPatcher(Gemma2ModelPatcher):
+# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147
+def _gemma3_mm_update_causal_mask(
+    self, attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training: bool = False
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted
+        # form and requires no inversion or slicing.
+        return attention_mask
+
+    min_dtype = torch.finfo(torch.float16).min
+    inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+    target_length = (
+        attention_mask.shape[-1]
+        if isinstance(attention_mask, torch.Tensor)
+        else cache_position[0] + sequence_length + 1
+    )
+
+    causal_mask = torch.full(
+        (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+    )
+
+    # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+
+    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+
+    # Apply bidirectional mask on images if token type ids are provided
+    if token_type_ids is not None and sequence_length != 1:
+        token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
+        token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
+        token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+        causal_mask = causal_mask.clone()
+        causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
+            token_type_mask, 0.0
+        )
+
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        mask_length = attention_mask.shape[-1]
+
+        # Then apply padding mask (will mask pad tokens)
+        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+        padding_mask = padding_mask == 0
+        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
+
+    return causal_mask
+
+
+class Gemma3LMModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
         config: "OnnxConfig",
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         model.__orig_forward = model.forward
+        model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, model)
+
+        # Difference from original:
+        # uses Dynamic cache from legacy cache instead of HybridCache
+        # calculate causal mask from multimodal
+        def forward(self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds):
+            from transformers.cache_utils import DynamicCache
+
+            pkv = DynamicCache.from_legacy_cache(past_key_values)

-        def forward(self, attention_mask, position_ids, past_key_values, inputs_embeds):
-            return self.__orig_forward(
+            past_seen_tokens = past_key_values[0][0].shape[-2]
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+            causal_mask = self._update_causal_mask_mm(
+                attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds
+            )
+
+            result = self.__orig_forward(
                 input_ids=None,
-                attention_mask=attention_mask,
+                attention_mask=causal_mask,
                 position_ids=position_ids,
-                past_key_values=past_key_values,
+                cache_position=cache_position,
+                past_key_values=pkv,
                 inputs_embeds=inputs_embeds,
             )
+            upd_pkv = result["past_key_values"]
+            result["past_key_values"] = upd_pkv.to_legacy_cache()
+            return result

         model.forward = types.MethodType(forward, model)
         super().__init__(config, model, model_kwargs)
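
The key idea of the patched mask is that text tokens keep causal attention while image tokens that share the same non-zero token type attend to each other bidirectionally. A simplified, self-contained toy sketch of just that masking step (it ignores the cache-position and padding handling of the real function, and uses float32 instead of float16):

    # Toy illustration of the bidirectional image block carved out of the causal mask.
    import torch

    token_type_ids = torch.tensor([[0, 1, 1, 1, 0]])  # 0 = text, 1 = image
    seq_len = token_type_ids.shape[1]
    min_dtype = torch.finfo(torch.float32).min

    # standard upper-triangular causal mask, shape (1, 1, seq_len, seq_len)
    causal_mask = torch.triu(torch.full((seq_len, seq_len), min_dtype), diagonal=1)[None, None]

    # positions that share a non-zero token type may attend to each other in both directions
    token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
    token_type_mask[token_type_ids == 0] = False  # text tokens stay causal
    causal_mask = causal_mask.masked_fill(token_type_mask.unsqueeze(1), 0.0)

    print((causal_mask[0, 0] == 0).int())  # 1 where attention is allowed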

optimum/intel/openvino/modeling_visual_language.py (+33)

@@ -130,6 +130,7 @@ def prepare_inputs(
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
         batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

@@ -177,6 +178,11 @@ def prepare_inputs(

         inputs["position_ids"] = position_ids

+        if "token_type_ids" in self.input_names:
+            if token_type_ids is None:
+                token_type_ids = np.zeros(inputs_embeds.shape[:2], dtype=int)
+            inputs["token_type_ids"] = token_type_ids
+
         if "beam_idx" in self.input_names:
             inputs["beam_idx"] = (
                 self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)

@@ -736,6 +742,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             position_ids=position_ids,
+            token_type_ids=token_type_ids,
             past_key_values=past_key_values,
             **kwargs,
         )

@@ -804,6 +811,11 @@ def prepare_inputs_for_generation(
         if attention_mask is not None and position_ids is None:
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
+
+            # position_ids in Gemma3 are 1-indexed
+            if self.config.model_type == "gemma3":
+                position_ids += 1
+
         if past_key_values is not None:
             position_ids = position_ids[:, -input_ids.shape[1] :]

@@ -829,6 +841,7 @@ def prepare_inputs_for_generation(
                 "pixel_values_videos": kwargs.get("pixel_values_videos"),
                 "image_grid_thw": kwargs.get("image_grid_thw"),
                 "video_grid_thw": kwargs.get("video_grid_thw"),
+                "token_type_ids": kwargs.get("token_type_ids"),
             }
         )
         return model_inputs

@@ -3119,6 +3132,7 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
     def merge_vision_text_embeddings(
         self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
     ):
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1323-L1339
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
         if input_ids is None:

@@ -3163,6 +3177,25 @@ def preprocess_inputs(
         inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt")
         return inputs

+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs=outputs,
+            model_kwargs=model_kwargs,
+            is_encoder_decoder=is_encoder_decoder,
+            num_new_tokens=num_new_tokens,
+        )
+
+        # Token type ids used only for first inference mask generation
+        model_kwargs.pop("token_type_ids", None)
+
+        return model_kwargs
+

 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,

tests/openvino/test_modeling.py (+27 −4)

@@ -1013,6 +1013,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     if is_openvino_version(">=", "2024.6.0") and platform.system() != "Windows":
         SUPPORTED_ARCHITECTURES += ("mixtral_awq",)

+    if is_transformers_version(">", "4.49"):
+        SUPPORTED_ARCHITECTURES += ("gemma3-text",)
+
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",

@@ -1112,7 +1115,7 @@ def test_compare_to_transformers(self, model_arch):
         gen_config = GenerationConfig(
             max_new_tokens=30,
             min_new_tokens=30,
-            num_beams=3,
+            num_beams=2,
             do_sample=False,
             eos_token_id=None,
         )

@@ -1126,7 +1129,7 @@ def test_compare_to_transformers(self, model_arch):
         additional_inputs = {}
         # gemma2 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache,
         # align cache representation in torch model
-        if model_arch == "gemma2":
+        if model_arch in ["gemma2", "gemma3-text"]:
             patch_update_causal_mask(transformers_model, "4.43.0")
             transformers_model._supports_cache_class = True
             from transformers.cache_utils import DynamicCache

@@ -2143,6 +2146,8 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
     if is_transformers_version(">=", "4.49.0"):
         SUPPORTED_ARCHITECTURES += ["qwen2_5_vl"]
         SUPPORT_VIDEO.append("qwen2_5_vl")
+    if is_transformers_version(">", "4.49"):
+        SUPPORTED_ARCHITECTURES += ["gemma3"]
     TASK = "image-text-to-text"
     REMOTE_CODE_MODELS = ["internvl2", "minicpmv", "nanollava", "phi3_v", "maira2"]

@@ -2154,7 +2159,13 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
     )

     def get_transformer_model_class(self, model_arch):
-        if is_transformers_version(">=", "4.46") and model_arch in ["llava", "llava_next", "qwen2_vl", "qwen2_5_vl"]:
+        if is_transformers_version(">=", "4.46") and model_arch in [
+            "llava",
+            "llava_next",
+            "qwen2_vl",
+            "qwen2_5_vl",
+            "gemma3",
+        ]:
             from transformers import AutoModelForImageTextToText

             return AutoModelForImageTextToText

@@ -2250,8 +2261,20 @@ def test_compare_to_transformers(self, model_arch):
         ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
         set_seed(SEED)

+        additional_inputs = {}
+        # gemma3 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache,
+        # align cache representation in torch model
+        if model_arch == "gemma3":
+            patch_update_causal_mask(transformers_model, "4.43.0")
+            transformers_model._supports_cache_class = True
+            from transformers.cache_utils import DynamicCache
+
+            additional_inputs = {"past_key_values": DynamicCache()}
+
         with torch.no_grad():
-            transformers_outputs = transformers_model.generate(**transformers_inputs, generation_config=gen_config)
+            transformers_outputs = transformers_model.generate(
+                **transformers_inputs, generation_config=gen_config, **additional_inputs
+            )

         # original minicpmv, internvl always skip input tokens in generation results, while transformers based approach provide them
         if model_arch in ["minicpmv", "internvl2"]:

tests/openvino/utils_tests.py (+2)

@@ -63,6 +63,8 @@
     "exaone": "katuni4ka/tiny-random-exaone",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "gemma2": "katuni4ka/tiny-random-gemma2",
+    "gemma3-text": "katuni4ka/tiny-random-gemma3-text",
+    "gemma3": "katuni4ka/tiny-random-gemma3",
     "falcon": "fxmarty/really-tiny-falcon-testing",
     "falcon-40b": "katuni4ka/tiny-random-falcon-40b",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
