@@ -4414,6 +4414,7 @@ def __init__(
        model_kwargs: Dict[str, Any],
    ):
        model.__orig_forward = model.forward
+        # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
        model.forward = model.get_image_features
        super().__init__(config, model, model_kwargs)
@@ -4422,23 +4423,94 @@ def __exit__(self, exc_type, exc_value, traceback):
        self._model.forward = self._model.__orig_forward


-class Gemma3LMModelPatcher(Gemma2ModelPatcher):
+# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147
+def _gemma3_mm_update_causal_mask(
+    self, attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training: bool = False
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted
+        # form and requires no inversion or slicing.
+        return attention_mask
+
+    min_dtype = torch.finfo(torch.float16).min
+    inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+    target_length = (
+        attention_mask.shape[-1]
+        if isinstance(attention_mask, torch.Tensor)
+        else cache_position[0] + sequence_length + 1
+    )
+
+    causal_mask = torch.full(
+        (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+    )
+
+    # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+
+    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+
+    # Apply bidirectional mask on images if token type ids are provided
+    if token_type_ids is not None and sequence_length != 1:
+        token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
+        token_type_mask[token_type_ids == 0] = False  # if text token, do not change anything
+        token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+        causal_mask = causal_mask.clone()
+        causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
+            token_type_mask, 0.0
+        )
+
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        mask_length = attention_mask.shape[-1]
+
+        # Then apply padding mask (will mask pad tokens)
+        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+        padding_mask = padding_mask == 0
+        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
+
+    return causal_mask
+
+
+class Gemma3LMModelPatcher(DecoderModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        model.__orig_forward = model.forward
+        model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, model)
+
+        # Differences from the original forward:
+        # - builds a DynamicCache from the legacy cache instead of using HybridCache
+        # - computes the causal mask for the multimodal inputs explicitly
+        def forward(self, attention_mask, position_ids, past_key_values, token_type_ids, inputs_embeds):
+            from transformers.cache_utils import DynamicCache
+
+            pkv = DynamicCache.from_legacy_cache(past_key_values)

-        def forward(self, attention_mask, position_ids, past_key_values, inputs_embeds):
-            return self.__orig_forward(
+            past_seen_tokens = past_key_values[0][0].shape[-2]
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+            causal_mask = self._update_causal_mask_mm(
+                attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds
+            )
+
+            result = self.__orig_forward(
                input_ids=None,
-                attention_mask=attention_mask,
+                attention_mask=causal_mask,
                position_ids=position_ids,
-                past_key_values=past_key_values,
+                cache_position=cache_position,
+                past_key_values=pkv,
                inputs_embeds=inputs_embeds,
            )
+            upd_pkv = result["past_key_values"]
+            result["past_key_values"] = upd_pkv.to_legacy_cache()
+            return result

        model.forward = types.MethodType(forward, model)
        super().__init__(config, model, model_kwargs)
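For reference, a minimal standalone sketch (not part of the patch) of the bidirectional image-token masking that `_gemma3_mm_update_causal_mask` applies on top of the causal mask. The `token_type_ids` convention (1 for image tokens, 0 for text) and the tiny sequence below are assumptions for illustration only.

```python
# Minimal sketch, assuming token_type_ids uses 1 for image tokens and 0 for text.
import torch

token_type_ids = torch.tensor([[0, 1, 1, 0]])  # text, image, image, text
seq_len = token_type_ids.shape[1]
min_dtype = torch.finfo(torch.float16).min

# Start from a standard causal mask (0 = attend, min_dtype = blocked).
causal_mask = torch.triu(torch.full((seq_len, seq_len), min_dtype), diagonal=1)
causal_mask = causal_mask[None, None, :, :].clone()

# Pairwise "same token type" matrix; rows belonging to text tokens are cleared,
# so only image tokens receive the extra (bidirectional) attention.
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False
token_type_mask = token_type_mask.unsqueeze(1)

causal_mask[:, :, :, :seq_len] = causal_mask[:, :, :, :seq_len].masked_fill(token_type_mask, 0.0)
print(causal_mask[0, 0])  # the image block (rows/cols 1-2) is fully unmasked
```

Text tokens keep the standard causal pattern; only the image-image block becomes visible in both directions, which is the behaviour the patched mask helper reproduces for export.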
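Likewise, a small sketch (again, not part of the patch) of the legacy-cache round trip the patched forward performs: the tuple-of-tuples `past_key_values` coming from the exported graph is wrapped into a `DynamicCache` for the call and converted back with `to_legacy_cache` afterwards. The shapes (2 layers, 4 heads, 3 past tokens, head size 8) are arbitrary and only for illustration.

```python
# Minimal sketch of the DynamicCache <-> legacy cache conversion used in the patched forward.
import torch
from transformers.cache_utils import DynamicCache

# Hypothetical legacy cache: a (key, value) pair of tensors per layer.
legacy_pkv = tuple((torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)) for _ in range(2))

pkv = DynamicCache.from_legacy_cache(legacy_pkv)

# cache_position is derived from the number of already-seen tokens, as in the patch.
past_seen_tokens = legacy_pkv[0][0].shape[-2]
new_tokens = 1  # e.g. one decoded token per step
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)

# After the model call, the cache is converted back so the exported graph keeps
# the tuple-of-tuples layout it was traced with.
roundtrip = pkv.to_legacy_cache()
assert roundtrip[0][0].shape == legacy_pkv[0][0].shape
print(cache_position)  # tensor([3])
```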