Commit 0f6e903

add support got-ocr2

Parent: 7f70f2b

4 files changed: +77 additions, -0 deletions

optimum/exporters/openvino/model_configs.py (+14)
@@ -79,6 +79,7 @@
     FalconModelPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
+    GotOCR2ImageEmbeddingsModelPatcher,
     GptBigCodeModelPatcher,
     GptJModelPatcher,
     GptNeoModelPatcher,

@@ -3001,3 +3002,16 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager("got-ocr2", *["image-text-to-text"], library_name="transformers")
+class GotOCR2OpenVINOConfig(LlavaOpenVINOConfig):
+    MIN_TRANSFORMERS_VERSION = "4.49.0"
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        model_kwargs = model_kwargs or {}
+        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return super().patch_model_for_export(model, model_kwargs)
+        return GotOCR2ImageEmbeddingsModelPatcher(self, model, model_kwargs)
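The new export config reuses the LLaVA config and only swaps in a dedicated patcher when the vision-embeddings sub-model is exported. A minimal, self-contained sketch of that dispatch pattern (toy classes and names, not the optimum-intel API) could look like this:

from enum import Enum
from typing import Any, Dict, Optional


class Behavior(Enum):
    LANGUAGE = "language"
    VISION_EMBEDDINGS = "vision_embeddings"


class BasePatcher:
    def __init__(self, config, model, model_kwargs: Optional[Dict[str, Any]] = None):
        self.config, self.model, self.model_kwargs = config, model, model_kwargs or {}


class VisionEmbeddingsPatcher(BasePatcher):
    """Stand-in for GotOCR2ImageEmbeddingsModelPatcher."""


class LlavaLikeConfig:
    def __init__(self, behavior: Behavior):
        self._behavior = behavior

    def patch_model_for_export(self, model, model_kwargs=None):
        return BasePatcher(self, model, model_kwargs)


class GotOCR2LikeConfig(LlavaLikeConfig):
    def patch_model_for_export(self, model, model_kwargs=None):
        model_kwargs = model_kwargs or {}
        # Only the vision-embeddings sub-model gets the dedicated patcher;
        # every other sub-model falls back to the inherited (LLaVA-style) patching.
        if self._behavior != Behavior.VISION_EMBEDDINGS:
            return super().patch_model_for_export(model, model_kwargs)
        return VisionEmbeddingsPatcher(self, model, model_kwargs)


config = GotOCR2LikeConfig(Behavior.VISION_EMBEDDINGS)
print(type(config.patch_model_for_export(model=object())).__name__)  # VisionEmbeddingsPatcher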

optimum/exporters/openvino/model_patcher.py (+17)
@@ -4405,3 +4405,20 @@ def __init__(
                 layer.mlp.down_proj.to(torch.float32)

         super().__init__(config, model, model_kwargs)
+
+
+class GotOCR2ImageEmbeddingsModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        model.__orig_forward = model.forward
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
+        model.forward = model.get_image_features
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
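The patcher temporarily rebinds model.forward to model.get_image_features so that tracing the vision sub-model exports only the image-feature path, and restores the original forward on exit. A toy illustration of the same swap-and-restore pattern, independent of optimum-intel and the real ModelPatcher base class:

class ForwardSwapPatcher:
    """Toy context manager: temporarily rebind model.forward to another bound method."""

    def __init__(self, model, replacement_attr: str):
        self._model = model
        self._replacement_attr = replacement_attr
        self._orig_forward = None

    def __enter__(self):
        self._orig_forward = self._model.forward
        self._model.forward = getattr(self._model, self._replacement_attr)
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original forward even if tracing/export raised.
        self._model.forward = self._orig_forward


class ToyModel:
    def forward(self, pixel_values):
        return "full generation path"

    def get_image_features(self, pixel_values):
        return "vision embeddings only"


model = ToyModel()
with ForwardSwapPatcher(model, "get_image_features"):
    print(model.forward(None))  # -> vision embeddings only
print(model.forward(None))      # -> full generation path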

optimum/exporters/openvino/utils.py (+1)
@@ -228,6 +228,7 @@ def get_submodels(model):
     "phi3-v",
     "qwen2-vl",
     "qwen2-5-vl",
+    "got-ocr2",
 ]
optimum/intel/openvino/modeling_visual_language.py (+45)
@@ -3109,6 +3109,50 @@ def preprocess_inputs(
         return processed_inputs


+class _OVGotOCR2ForCausalLM(OVModelForVisualCausalLM):
+    def get_vision_embeddings(self, pixel_values, input_ids, **kwargs):
+        if input_ids is not None and input_ids.shape[1] == 1 and kwargs.get("past_key_values") is not None:
+            return None
+        return self.vision_embeddings(pixel_values).last_hidden_state
+
+    def merge_vision_text_embeddings(
+        self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
+    ):
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L836-L845
+        image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
+        inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
+        n_image_tokens = (input_ids == self.config.image_token_index).sum()
+        n_image_features = image_features.shape[0] * image_features.shape[1]
+        if n_image_tokens != n_image_features:
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        return inputs_embeds, attention_mask, position_ids
+
+    @staticmethod
+    def preprocess_inputs(
+        text: Optional[str] = None,
+        image: Optional["Image"] = None,
+        processor: Optional[AutoImageProcessor] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+        config: Optional[PretrainedConfig] = None,
+        video: Optional["VideoInput"] = None,
+    ):
+        if processor is None:
+            raise ValueError("processor is required")
+        if video is not None:
+            raise ValueError("Video input is not supported")
+        if image is None:
+            raise ValueError("Image is required")
+        processed_inputs = processor(image, return_tensors="pt")
+        return processed_inputs
+
+
 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,
     "llava_next": _OVLlavaNextForCausalLM,

@@ -3120,4 +3164,5 @@ def preprocess_inputs(
     "internvl_chat": _OVInternVLForCausalLM,
     "qwen2_vl": _OVQwen2VLForCausalLM,
     "qwen2_5_vl": _OVQwen2_5_VLForCausalLM,
+    "got_ocr2": _OVGotOCR2ForCausalLM,
 }
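With the pieces above in place, a GOT-OCR2 checkpoint can be exported and run like any other VLM supported by OVModelForVisualCausalLM. A rough end-to-end sketch; the checkpoint name stepfun-ai/GOT-OCR-2.0-hf, the prompt-free processor call, and the generation settings are assumptions rather than something taken from this commit:

from PIL import Image
from transformers import AutoProcessor

from optimum.intel import OVModelForVisualCausalLM

# Assumed checkpoint name for the transformers port of GOT-OCR2; not part of this commit.
model_id = "stepfun-ai/GOT-OCR-2.0-hf"

processor = AutoProcessor.from_pretrained(model_id)
# export=True converts the checkpoint to OpenVINO IR using the "got-ocr2" export config
# registered above; at runtime the model is handled by _OVGotOCR2ForCausalLM.
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

image = Image.open("document.png")
inputs = processor(image, return_tensors="pt")  # mirrors _OVGotOCR2ForCausalLM.preprocess_inputs
generated_ids = model.generate(**inputs, max_new_tokens=256)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])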
