@@ -573,6 +573,29 @@ def test_auto_device_loading(self):
         del model
         gc.collect()
 
+    def test_default_filling_attention_mask(self):
+        model_id = MODEL_NAMES["gpt2"]
+        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+        texts = ["this is a simple input"]
+        tokens = tokenizer(texts, return_tensors="pt")
+        self.assertTrue("attention_mask" in model_with_cache.input_names)
+        outs = model_with_cache(**tokens)
+        attention_mask = tokens.pop("attention_mask")
+        outs_without_attn_mask = model_with_cache(**tokens)
+        self.assertTrue(torch.allclose(outs.logits, outs_without_attn_mask.logits))
+        input_ids = torch.argmax(outs.logits, dim=2)
+        past_key_values = outs.past_key_values
+        attention_mask = torch.ones((input_ids.shape[0], tokens.input_ids.shape[1] + 1), dtype=torch.long)
+        outs_step2 = model_with_cache(
+            input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values
+        )
+        outs_without_attn_mask_step2 = model_with_cache(input_ids=input_ids, past_key_values=past_key_values)
+        self.assertTrue(torch.allclose(outs_step2.logits, outs_without_attn_mask_step2.logits))
+        del model_with_cache
+        gc.collect()
+
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
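For reference, the behavior this test pins down (a missing attention_mask being filled with a default all-ones mask, so the outputs match the explicit-mask call) can be reproduced outside the test harness. A minimal sketch, assuming an optimum-intel installation where OVModelForCausalLM is importable from optimum.intel, and using the public "gpt2" checkpoint as a stand-in for the suite's MODEL_NAMES entry:

import torch
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM  # assumed import path

# Export gpt2 to OpenVINO IR on the fly, keeping the KV-cache inputs
# (same from_pretrained arguments as in the test above).
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = tokenizer("this is a simple input", return_tensors="pt")

# Run once with the attention mask passed explicitly...
with_mask = model(**tokens)
# ...and once with it omitted, so the model must fill a default all-ones mask.
tokens.pop("attention_mask")
without_mask = model(**tokens)

# The test asserts these are close; here we just print the comparison.
print(torch.allclose(with_mask.logits, without_mask.logits))  # expected: True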