Commit 7947b6a

add gemma3 support
1 parent 7f70f2b commit 7947b6a

File tree: 4 files changed (+158, -1 lines)

optimum/exporters/openvino/model_configs.py (+59)
@@ -79,6 +79,8 @@
     FalconModelPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
+    Gemma3ImageEmbeddingsModelPatcher,
+    Gemma3LMModelPatcher,
     GptBigCodeModelPatcher,
     GptJModelPatcher,
     GptNeoModelPatcher,
@@ -142,6 +144,10 @@ def init_model_configs():
         "transformers",
         "AutoModelForVision2Seq",
     )
+    TasksManager._CUSTOM_CLASSES[("pt", "gemma3", "image-text-to-text")] = (
+        "transformers",
+        "Gemma3ForConditionalGeneration",
+    )

     TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
         "image-text-to-text"
@@ -1140,6 +1146,21 @@ def patch_model_for_export(
         return Gemma2ModelPatcher(self, model, model_kwargs=model_kwargs)


+@register_in_tasks_manager(
+    "gemma3-text",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class Gemma3TextOpenVINOConfig(Gemma2OpenVINOConfig):
+    MIN_TRANSFORMERS_VERSION = version.parse("4.49.0")
+
+
 class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
     def __init__(
         self,
@@ -3001,3 +3022,41 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("gemma3", *["image-text-to-text"], library_name="transformers")
+class Gemma3OpenVINOConfig(LlavaOpenVINOConfig):
+    MIN_TRANSFORMERS_VERSION = version.parse("4.49.0")
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        model_kwargs = model_kwargs or {}
+        # Only the vision tower needs the Gemma3-specific patcher; the other
+        # behaviors reuse the LLaVA-style patching.
+        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return super().patch_model_for_export(model, model_kwargs)
+        return Gemma3ImageEmbeddingsModelPatcher(self, model, model_kwargs)
+
+    def with_behavior(
+        self,
+        behavior: Union[str, LlavaConfigBehavior],
+    ):
+        """
+        Creates a config for a different behavior.
+
+        Args:
+            behavior ([`ConfigBehavior`]):
+                The behavior to use for the new instance.
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
+            behavior = LlavaConfigBehavior(behavior)
+
+        if behavior == LlavaConfigBehavior.LANGUAGE:
+            model_type = self._orig_config.text_config.model_type
+            return get_vlm_text_generation_config(
+                model_type,
+                self._orig_config.text_config,
+                self.int_dtype,
+                self.float_dtype,
+                model_patcher=Gemma3LMModelPatcher,
+            )
+        return super().with_behavior(behavior)
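With gemma3 and its text backbone gemma3-text registered in the tasks manager, a checkpoint can be converted like any other supported VLM. A minimal sketch, assuming transformers >= 4.49; the checkpoint ID is illustrative:

# Sketch: exporting a Gemma3 VLM through the configs registered above.
# "google/gemma-3-4b-it" stands in for any Gemma3 image-text-to-text checkpoint.
from optimum.intel import OVModelForVisualCausalLM

ov_model = OVModelForVisualCausalLM.from_pretrained("google/gemma-3-4b-it", export=True)
ov_model.save_pretrained("gemma3-ov")  # vision, text-embedding and language submodels are written as separate IRs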

optimum/exporters/openvino/model_patcher.py (+42, -1)
@@ -2803,7 +2803,6 @@ def patched_forward(*args, **kwargs):

         signature = inspect.signature(self.orig_forward)
         args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
-
         return_legacy_cache = False
         pkv_in_args = False
         legacy_pkv = None
@@ -4405,3 +4404,45 @@ def __init__(
             layer.mlp.down_proj.to(torch.float32)

         super().__init__(config, model, model_kwargs)
+
+
+class Gemma3ImageEmbeddingsModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        model.__orig_forward = model.forward
+        # Trace only the vision tower: export get_image_features in place of the full forward.
+        model.forward = model.get_image_features
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
+
+
+class Gemma3LMModelPatcher(Gemma2ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        # Expose a signature that takes precomputed inputs_embeds so the language
+        # model can be traced without the token-embedding lookup.
+        def forward(self, attention_mask, position_ids, past_key_values, inputs_embeds):
+            return self.__orig_forward(
+                input_ids=None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+            )
+
+        model.forward = types.MethodType(forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
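Both patchers follow the existing ModelPatcher contract: forward is swapped when the patcher is constructed and restored in __exit__, so export code can use them as context managers. A standalone sketch of the same rebinding trick on a toy class (not optimum API):

import types

class ToyModel:
    def forward(self, x):
        return x + 1

model = ToyModel()
orig_forward = model.forward  # keep the original bound method

def forward(self, x):
    # wrapper with the signature we want traced; delegates to the original
    return orig_forward(x) * 2

model.forward = types.MethodType(forward, model)
assert model.forward(1) == 4  # patched behavior
model.forward = orig_forward  # restore, as the patchers' __exit__ does
assert model.forward(1) == 2  # original behavior back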

optimum/exporters/openvino/utils.py (+1)

@@ -228,6 +228,7 @@ def get_submodels(model):
     "phi3-v",
     "qwen2-vl",
     "qwen2-5-vl",
+    "gemma3",
 ]

optimum/intel/openvino/modeling_visual_language.py (+56)
@@ -711,6 +711,7 @@ def forward(
         rope_deltas=None,
         images=None,
         second_per_grid_ts=None,
+        token_type_ids=None,
         **kwargs,
     ):
         pixel_values = pixel_values if pixel_values is not None else images
@@ -3109,6 +3110,60 @@ def preprocess_inputs(
         return processed_inputs


+class _OVGemma3ForCausalLM(OVModelForVisualCausalLM):
+    def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
+        # During autoregressive decoding (single-token step) no new image features are needed.
+        if input_ids is not None and input_ids.shape[1] == 1:
+            return None
+        return self.vision_embeddings(pixel_values).last_hidden_state
+
+    def merge_vision_text_embeddings(
+        self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
+    ):
+        image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
+        inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
+        if input_ids is None:
+            special_image_mask = inputs_embeds == torch.from_numpy(
+                self.get_text_embeddings(torch.tensor([[self.config.image_token_index]], dtype=torch.long))[0]
+            )
+        else:
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds)
+
+        image_features = image_features.to(inputs_embeds.dtype)
+        # Scatter image features into the image-token placeholder positions.
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        return inputs_embeds, attention_mask, position_ids
+
+    @staticmethod
+    def preprocess_inputs(
+        text: str,
+        image: Optional["Image"] = None,
+        processor: Optional[AutoImageProcessor] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+        config: Optional[PretrainedConfig] = None,
+        video: Optional["VideoInput"] = None,
+    ):
+        if processor is None:
+            raise ValueError("Processor is required.")
+        if video is not None:
+            raise ValueError("Video input is not supported")
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": text},
+                ],
+            }
+        ]
+        if image is not None:
+            conversation[0]["content"].insert(0, {"type": "image"})
+
+        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
+        inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt")
+        return inputs
+
+
 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,
     "llava_next": _OVLlavaNextForCausalLM,
@@ -3120,4 +3175,5 @@ def preprocess_inputs(
     "internvl_chat": _OVInternVLForCausalLM,
     "qwen2_vl": _OVQwen2VLForCausalLM,
     "qwen2_5_vl": _OVQwen2_5_VLForCausalLM,
+    "gemma3": _OVGemma3ForCausalLM,
 }
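On the inference side, _OVGemma3ForCausalLM is resolved through MODEL_TYPE_TO_CLS_MAPPING, so an exported checkpoint runs through the usual OVModelForVisualCausalLM API. A hedged end-to-end sketch; the model ID, image path and prompt are illustrative:

from PIL import Image
from transformers import AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "google/gemma-3-4b-it"  # assumption: any Gemma3 image-text-to-text checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

image = Image.open("cat.png")  # hypothetical local image
# preprocess_inputs applies the chat template and tokenizes text and image together
inputs = model.preprocess_inputs(text="Describe this image.", image=image, processor=processor)
out = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(out, skip_special_tokens=True)[0])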
