@@ -2803,7 +2803,6 @@ def patched_forward(*args, **kwargs):
            signature = inspect.signature(self.orig_forward)
            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
-
            return_legacy_cache = False
            pkv_in_args = False
            legacy_pkv = None
@@ -4407,7 +4406,7 @@ def __init__(
        super().__init__(config, model, model_kwargs)


-class GotOCR2ImageEmbeddingsModelPatcher(ModelPatcher):
+class CommonImageEmbeddingsModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
@@ -4416,9 +4415,107 @@ def __init__(
    ):
        model.__orig_forward = model.forward
        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
        model.forward = model.get_image_features
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward
+
+
+# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147
+def _gemma3_mm_update_causal_mask(
+    self, attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training: bool = False
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted
+        # form and requires no inversion or slicing.
+        return attention_mask
+
+    min_dtype = torch.finfo(torch.float16).min
+    inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+    target_length = (
+        attention_mask.shape[-1]
+        if isinstance(attention_mask, torch.Tensor)
+        else cache_position[0] + sequence_length + 1
+    )
+
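+    # Start from a fully masked additive bias (min_dtype everywhere); attendable positions are zeroed out below.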
+    causal_mask = torch.full(
+        (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+    )
+
+    # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+
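+    # Un-mask (zero) every position up to and including each query token's absolute position; strictly later positions keep the min_dtype bias.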
+    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+
+    # Apply bidirectional mask on images if token type ids are provided
+    if token_type_ids is not None and sequence_length != 1:
+        token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
+        token_type_mask[token_type_ids == 0] = False  # if text token, do not change anything
+        token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+        causal_mask = causal_mask.clone()
+        causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
+            token_type_mask, 0.0
+        )
+
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        mask_length = attention_mask.shape[-1]
+
+        # Then apply padding mask (will mask pad tokens)
+        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+        padding_mask = padding_mask == 0
+        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
+
+    return causal_mask
+
+
+class Gemma3LMModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        model.__orig_forward = model.forward
+        model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, model)
+
+        # Differences from the original:
+        # - uses a DynamicCache built from the legacy cache instead of HybridCache
+        # - computes the causal mask for the multimodal inputs
+        def forward(self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds):
+            from transformers.cache_utils import DynamicCache
+
+            pkv = DynamicCache.from_legacy_cache(past_key_values)
+
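+            # The legacy cache is a per-layer tuple of (key, value) tensors shaped [batch, num_heads, seq_len, head_dim],
+            # so dim -2 is the number of tokens already processed.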
+            past_seen_tokens = past_key_values[0][0].shape[-2]
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+            causal_mask = self._update_causal_mask_mm(
+                attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds
+            )
+
+            result = self.__orig_forward(
+                input_ids=None,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                past_key_values=pkv,
+                inputs_embeds=inputs_embeds,
+            )
+            upd_pkv = result["past_key_values"]
+            result["past_key_values"] = upd_pkv.to_legacy_cache()
+            return result
+
+        model.forward = types.MethodType(forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
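A minimal usage sketch for orientation (hypothetical: onnx_config and language_model are placeholders, not names from this diff; the only assumption is the existing ModelPatcher context-manager protocol, which the __exit__ override above participates in):

# Hypothetical driver code for the patcher above
patcher = Gemma3LMModelPatcher(onnx_config, language_model, model_kwargs=None)
with patcher:
    # while patched, forward takes (attention_mask, position_ids, past_key_values,
    # token_type_ids, inputs_embeds) and returns past_key_values as legacy tuples
    ...  # trace / export the language model here
# leaving the context restores the original forward via __exit__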