
Commit b7703dc

Authored Oct 20, 2023
allow optionally fill attention mask in forward (#457)
1 parent 8273e7f
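
With this change, attention_mask becomes optional in OVModelForCausalLM.forward: when it is omitted, a mask of ones covering both the freshly passed tokens and the cached past length is filled in automatically. A minimal usage sketch of the new behavior (the "gpt2" model id is only an illustrative choice, mirroring the test below):

    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model = OVModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokens = tokenizer("this is a simple input", return_tensors="pt")

    # Both calls should now yield the same logits: the second lets forward()
    # build a default attention_mask of ones with the right length.
    out_with_mask = model(**tokens)
    out_without_mask = model(input_ids=tokens["input_ids"])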

File tree: 2 files changed (+33 −2 lines)


optimum/intel/openvino/modeling_decoder.py (+10 −2)
@@ -352,7 +352,10 @@ def forward(
             input_ids = input_ids[:, -1:]
 
         inputs = {}
+        past_len = 0
         if past_key_values is not None:
+            seq_len_dim = 1 if self.model.input(self.key_value_input_names[0]).get_partial_shape()[1].is_dynamic else 2
+            past_len = past_key_values[0][0].shape[seq_len_dim]
             if self._pkv_precision == Type.bf16:
                 # numpy does not support bf16, pretending f16, should change to bf16
                 past_key_values = tuple(
@@ -387,8 +390,13 @@ def forward(
         inputs["input_ids"] = np.array(input_ids)
 
         # Add the attention_mask inputs when needed
-        if "attention_mask" in self.input_names and attention_mask is not None:
-            inputs["attention_mask"] = np.array(attention_mask)
+        if "attention_mask" in self.input_names:
+            if attention_mask is not None:
+                inputs["attention_mask"] = np.array(attention_mask)
+            else:
+                inputs["attention_mask"] = np.ones(
+                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
+                )
 
         # Run inference
         self.request.start_async(inputs, shared_memory=True)
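
When no mask is supplied, the fallback must span the full sequence seen so far, i.e. the new input tokens plus everything already in the KV cache; seq_len_dim handles the two possible cache layouts (sequence length in dimension 1 or 2 of each past key/value tensor). A standalone numpy sketch of the shape computation, assuming the common (batch, num_heads, seq_len, head_dim) layout:

    import numpy as np

    batch_size, num_heads, past_len, head_dim = 1, 12, 5, 64
    # One cached key tensor; sequence length sits in dimension 2 here
    past_key = np.zeros((batch_size, num_heads, past_len, head_dim), dtype=np.float32)
    input_ids = np.array([[42]])  # one new token per sequence

    # Default mask: ones over the cached tokens plus the new ones
    attention_mask = np.ones(
        (input_ids.shape[0], input_ids.shape[1] + past_key.shape[2]), dtype=input_ids.dtype
    )
    assert attention_mask.shape == (1, 6)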

tests/openvino/test_modeling.py (+23 −0)
@@ -573,6 +573,29 @@ def test_auto_device_loading(self):
         del model
         gc.collect()
 
+    def test_default_filling_attention_mask(self):
+        model_id = MODEL_NAMES["gpt2"]
+        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+        texts = ["this is a simple input"]
+        tokens = tokenizer(texts, return_tensors="pt")
+        self.assertTrue("attention_mask" in model_with_cache.input_names)
+        outs = model_with_cache(**tokens)
+        attention_mask = tokens.pop("attention_mask")
+        outs_without_attn_mask = model_with_cache(**tokens)
+        self.assertTrue(torch.allclose(outs.logits, outs_without_attn_mask.logits))
+        input_ids = torch.argmax(outs.logits, dim=2)
+        past_key_values = outs.past_key_values
+        attention_mask = torch.ones((input_ids.shape[0], tokens.input_ids.shape[1] + 1), dtype=torch.long)
+        outs_step2 = model_with_cache(
+            input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values
+        )
+        outs_without_attn_mask_step2 = model_with_cache(input_ids=input_ids, past_key_values=past_key_values)
+        self.assertTrue(torch.allclose(outs_step2.logits, outs_without_attn_mask_step2.logits))
+        del model_with_cache
+        gc.collect()
+
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
