Skip to content

Commit 7055210

Browse files
authored
add gemma3 support (#1198)
* add gemma3 support * add tests and docs * add tests against main * temporary change test dataset for hybrid quant * Apply suggestions from code review * Update tests/openvino/test_exporters_cli.py
1 parent 5ac3544 commit 7055210

File tree

8 files changed

+286
-9
lines changed

8 files changed

+286
-9
lines changed

.github/workflows/test_openvino_slow.yml

+8-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ jobs:
3131
os: "ubuntu-22.04"
3232
- transformers-version: "4.45.0"
3333
os: "ubuntu-22.04"
34+
- transformers-version: "main"
35+
os: "ubuntu-22.04"
3436

3537
runs-on: ${{ matrix.os }}
3638

@@ -50,14 +52,18 @@ jobs:
5052
pip install .[openvino,tests] transformers[testing]
5153
pip uninstall -y nncf
5254
53-
- if: ${{ matrix.transformers-version != 'latest' }}
55+
- if: ${{ matrix.transformers-version != 'latest' && matrix.transformers-version != 'main' }}
5456
name: Install specific dependencies and versions required for older transformers
5557
run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.*, diffusers==0.30.* transformers_stream_generator
5658

57-
- if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
59+
- if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' || matrix.transformers-version == 'main' }}
5860
name: Install auto-gptq, autoawq
5961
run: |
6062
pip install auto-gptq "autoawq<0.2.8" --extra-index-url https://download.pytorch.org/whl/cpu
63+
64+
- if: ${{ matrix.transformers-version == 'main' }}
65+
name: Install transformers from repository
66+
run: pip install git+https://github.com/huggingface/transformers.git
6167

6268
- name: Pip freeze
6369
run: pip freeze

docs/source/openvino/models.mdx

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ Here is the list of the supported architectures :
6262
- GPT-NeoX-Japanese
6363
- Gemma
6464
- Gemma2
65+
- Gemma3
6566
- GOT-OCR 2.0
6667
- Granite
6768
- GraniteMoE

optimum/exporters/openvino/model_configs.py

+65-2
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,14 @@
7373
BaichuanModelPatcher,
7474
ChatGLMModelPatcher,
7575
CodeGenModelPatcher,
76+
CommonImageEmbeddingsModelPatcher,
7677
DBRXModelPatcher,
7778
DeciLMModelPatcher,
7879
DeepseekPatcher,
7980
FalconModelPatcher,
8081
FluxTransfromerModelPatcher,
8182
Gemma2ModelPatcher,
82-
GotOCR2ImageEmbeddingsModelPatcher,
83+
Gemma3LMModelPatcher,
8384
GptBigCodeModelPatcher,
8485
GptJModelPatcher,
8586
GptNeoModelPatcher,
@@ -143,6 +144,10 @@ def init_model_configs():
143144
"transformers",
144145
"AutoModelForVision2Seq",
145146
)
147+
TasksManager._CUSTOM_CLASSES[("pt", "gemma3", "image-text-to-text")] = (
148+
"transformers",
149+
"Gemma3ForConditionalGeneration",
150+
)
146151

147152
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
148153
"image-text-to-text"
@@ -1141,6 +1146,21 @@ def patch_model_for_export(
11411146
return Gemma2ModelPatcher(self, model, model_kwargs=model_kwargs)
11421147

11431148

1149+
@register_in_tasks_manager(
1150+
"gemma3-text",
1151+
*[
1152+
"feature-extraction",
1153+
"feature-extraction-with-past",
1154+
"text-generation",
1155+
"text-generation-with-past",
1156+
"text-classification",
1157+
],
1158+
library_name="transformers",
1159+
)
1160+
class Gemma3TextOpenVINOConfig(Gemma2OpenVINOConfig):
1161+
MIN_TRANSFORMERS_VERSION = version.parse("4.50.0")
1162+
1163+
11441164
class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
11451165
def __init__(
11461166
self,
@@ -1402,6 +1422,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
14021422
inputs_embed_shape
14031423
)
14041424
dummy_inputs["inputs_embeds"] = inputs_embeds
1425+
if "token_type_ids" in self.inputs:
1426+
dummy_inputs["token_type_ids"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[
1427+
0
1428+
].random_int_tensor(input_ids.shape, min_value=0, max_value=2)
14051429
return dummy_inputs
14061430

14071431

@@ -3014,4 +3038,43 @@ def patch_model_for_export(
30143038
model_kwargs = model_kwargs or {}
30153039
if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
30163040
return super().patch_model_for_export(model, model_kwargs)
3017-
return GotOCR2ImageEmbeddingsModelPatcher(self, model, model_kwargs)
3041+
return CommonImageEmbeddingsModelPatcher(self, model, model_kwargs)
3042+
3043+
3044+
@register_in_tasks_manager("gemma3", *["image-text-to-text"], library_name="transformers")
3045+
class Gemma3OpenVINOConfig(LlavaOpenVINOConfig):
3046+
MIN_TRANSFORMERS_VERSION = "4.50.0"
3047+
3048+
def patch_model_for_export(
3049+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
3050+
):
3051+
model_kwargs = model_kwargs or {}
3052+
if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
3053+
return super().patch_model_for_export(model, model_kwargs)
3054+
return CommonImageEmbeddingsModelPatcher(self, model, model_kwargs)
3055+
3056+
def with_behavior(
3057+
self,
3058+
behavior: Union[str, LlavaConfigBehavior],
3059+
):
3060+
"""
3061+
Creates a config for different behaviour.
3062+
3063+
Args:
3064+
behavior ([`ConfigBehavior`]):
3065+
The behavior to use for the new instance.
3066+
"""
3067+
if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
3068+
behavior = LlavaConfigBehavior(behavior)
3069+
3070+
if behavior == LlavaConfigBehavior.LANGUAGE:
3071+
model_type = self._orig_config.text_config.model_type
3072+
return get_vlm_text_generation_config(
3073+
model_type,
3074+
self._orig_config.text_config,
3075+
self.int_dtype,
3076+
self.float_dtype,
3077+
model_patcher=Gemma3LMModelPatcher,
3078+
inputs_update={"token_type_ids": {0: "batch_size", 1: "sequence_length"}},
3079+
)
3080+
return super().with_behavior(behavior)

optimum/exporters/openvino/model_patcher.py

+99-2
Original file line numberDiff line numberDiff line change
@@ -2803,7 +2803,6 @@ def patched_forward(*args, **kwargs):
28032803

28042804
signature = inspect.signature(self.orig_forward)
28052805
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
2806-
28072806
return_legacy_cache = False
28082807
pkv_in_args = False
28092808
legacy_pkv = None
@@ -4407,7 +4406,7 @@ def __init__(
44074406
super().__init__(config, model, model_kwargs)
44084407

44094408

4410-
class GotOCR2ImageEmbeddingsModelPatcher(ModelPatcher):
4409+
class CommonImageEmbeddingsModelPatcher(ModelPatcher):
44114410
def __init__(
44124411
self,
44134412
config: "OnnxConfig",
@@ -4416,9 +4415,107 @@ def __init__(
44164415
):
44174416
model.__orig_forward = model.forward
44184417
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
4418+
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
44194419
model.forward = model.get_image_features
44204420
super().__init__(config, model, model_kwargs)
44214421

44224422
def __exit__(self, exc_type, exc_value, traceback):
44234423
super().__exit__(exc_type, exc_value, traceback)
44244424
self._model.forward = self._model.__orig_forward
4425+
4426+
4427+
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147
4428+
def _gemma3_mm_update_causal_mask(
4429+
self, attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training: bool = False
4430+
):
4431+
if attention_mask is not None and attention_mask.dim() == 4:
4432+
# In this case we assume that the mask comes already in inverted
4433+
# form and requires no inversion or slicing.
4434+
return attention_mask
4435+
4436+
min_dtype = torch.finfo(torch.float16).min
4437+
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
4438+
target_length = (
4439+
attention_mask.shape[-1]
4440+
if isinstance(attention_mask, torch.Tensor)
4441+
else cache_position[0] + sequence_length + 1
4442+
)
4443+
4444+
causal_mask = torch.full(
4445+
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
4446+
)
4447+
4448+
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
4449+
if sequence_length != 1:
4450+
causal_mask = torch.triu(causal_mask, diagonal=1)
4451+
4452+
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
4453+
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
4454+
4455+
# Apply bidirectional mask on images if token type ids are provided
4456+
if token_type_ids is not None and sequence_length != 1:
4457+
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
4458+
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
4459+
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
4460+
causal_mask = causal_mask.clone()
4461+
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
4462+
token_type_mask, 0.0
4463+
)
4464+
4465+
if attention_mask is not None:
4466+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
4467+
mask_length = attention_mask.shape[-1]
4468+
4469+
# Then apply padding mask (will mask pad tokens)
4470+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
4471+
padding_mask = padding_mask == 0
4472+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
4473+
4474+
return causal_mask
4475+
4476+
4477+
class Gemma3LMModelPatcher(DecoderModelPatcher):
4478+
def __init__(
4479+
self,
4480+
config: "OnnxConfig",
4481+
model: Union["PreTrainedModel", "TFPreTrainedModel"],
4482+
model_kwargs: Optional[Dict[str, Any]] = None,
4483+
):
4484+
model.__orig_forward = model.forward
4485+
model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, model)
4486+
4487+
# Difference from original:
4488+
# uses Dynamic cache from legacy cache instead of HybridCache
4489+
# calculate causal mask from multimodal
4490+
def forward(self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds):
4491+
from transformers.cache_utils import DynamicCache
4492+
4493+
pkv = DynamicCache.from_legacy_cache(past_key_values)
4494+
4495+
past_seen_tokens = past_key_values[0][0].shape[-2]
4496+
cache_position = torch.arange(
4497+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
4498+
)
4499+
4500+
causal_mask = self._update_causal_mask_mm(
4501+
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds
4502+
)
4503+
4504+
result = self.__orig_forward(
4505+
input_ids=None,
4506+
attention_mask=causal_mask,
4507+
position_ids=position_ids,
4508+
cache_position=cache_position,
4509+
past_key_values=pkv,
4510+
inputs_embeds=inputs_embeds,
4511+
)
4512+
upd_pkv = result["past_key_values"]
4513+
result["past_key_values"] = upd_pkv.to_legacy_cache()
4514+
return result
4515+
4516+
model.forward = types.MethodType(forward, model)
4517+
super().__init__(config, model, model_kwargs)
4518+
4519+
def __exit__(self, exc_type, exc_value, traceback):
4520+
super().__exit__(exc_type, exc_value, traceback)
4521+
self._model.forward = self._model.__orig_forward

optimum/exporters/openvino/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def get_submodels(model):
229229
"qwen2-vl",
230230
"qwen2-5-vl",
231231
"got-ocr2",
232+
"gemma3",
232233
]
233234

234235

0 commit comments

Comments
 (0)