
Commit 32ceae1

jiqing-feng authored

fix when attention_mask=None (#1067)

* fix when attention_mask=None
* fix position_ids
* add tests for forward only with input_ids
* fix input dtype

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

1 parent 847adbc commit 32ceae1

File tree

3 files changed: +37 -4 lines changed


optimum/exporters/ipex/modeling_utils.py

+3 -3
@@ -180,7 +180,7 @@ def _llama_model_forward(
         position_ids = torch.arange(
             past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
         )
-        position_ids = position_ids.unsqueeze(0)
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
 
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
@@ -297,7 +297,7 @@ def _falcon_model_forward(
         )
 
     if position_ids is None:
-        position_ids = cache_position.unsqueeze(0)
+        position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
 
     # Prepare head mask if needed
     # 1.0 in head_mask indicate we keep the head
@@ -419,7 +419,7 @@ def _gpt2_model_forward(
     past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
     if position_ids is None:
         position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-        position_ids = position_ids.unsqueeze(0)
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
 
     if inputs_embeds is None:
         inputs_embeds = self.wte(input_ids)
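All three hunks make the same one-line change. A minimal sketch of why it matters, with illustrative values that are not part of the commit: unsqueeze(0) alone produces position ids with a batch dimension of 1, while the patched forward needs one row of position ids per sequence in the batch.

    import torch

    # Hypothetical batch of two sequences of length 3, for illustration only.
    input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
    past_key_values_length = 0
    seq_length = input_ids.shape[1]

    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
    )
    print(position_ids.unsqueeze(0).shape)  # torch.Size([1, 3]), whatever the batch size

    # The fix repeats that single row once per batch element:
    position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
    print(position_ids.shape)  # torch.Size([2, 3]), one row per sequence
    print(position_ids)        # tensor([[0, 1, 2], [0, 1, 2]])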

optimum/intel/ipex/modeling_base.py

+2
@@ -276,6 +276,8 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
+        if self.add_patch and input_ids is not None and attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
 
     def _prepare_generation_config(
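The two added lines supply a default mask whenever the caller passes attention_mask=None (self.add_patch is the existing flag marking models that use the patched IPEX forward). A small illustrative sketch of the default being built:

    import torch

    # Illustrative input: two unpadded sequences of length 4.
    input_ids = torch.tensor([[11, 12, 13, 14], [15, 16, 17, 18]])

    # With attention_mask=None, the patched forward now synthesizes an
    # all-ones mask of the same shape and dtype, i.e. attend to every token:
    attention_mask = torch.ones_like(input_ids)
    print(attention_mask)
    # tensor([[1, 1, 1, 1],
    #         [1, 1, 1, 1]])

An all-ones mask treats every position as a real token, which is a safe default only for unpadded batches such as the plain input_ids-only calls the new test exercises.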

tests/ipex/test_modeling.py

+32 -1
@@ -241,7 +241,6 @@ def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
-        # Test model forward do not need cache.
         ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -275,6 +274,38 @@ def test_compare_to_transformers(self, model_arch):
         self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
         self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_forward(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
+        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
+        self.assertIsInstance(ipex_model.config, PretrainedConfig)
+        input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long)
+        outputs = ipex_model(input_ids)
+
+        self.assertIsInstance(outputs.logits, torch.Tensor)
+
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
+        with torch.no_grad():
+            transformers_outputs = transformers_model(input_ids)
+
+        # Test re-load model
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ipex_model.save_pretrained(tmpdirname)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE)
+            loaded_model_outputs = loaded_model(input_ids)
+
+        # Test init method
+        init_model = self.IPEX_MODEL_CLASS(transformers_model)
+        init_model_outputs = init_model(input_ids)
+
+        # Compare tensor outputs
+        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
+        # To avoid float pointing error
+        self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
+        self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))
+
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
         dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
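The new test_forward mirrors test_compare_to_transformers but drives the model with input_ids alone. A condensed usage sketch of the behavior it verifies; the checkpoint name here is illustrative, not taken from the test:

    import torch
    from optimum.intel import IPEXModelForCausalLM

    # Any architecture in SUPPORTED_ARCHITECTURES would do; "gpt2" is an assumption.
    model = IPEXModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)

    input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)

    # No attention_mask and no position_ids: with this commit the model fills
    # in both (all-ones mask, one row of position ids per sequence).
    with torch.no_grad():
        outputs = model(input_ids)

    print(outputs.logits.shape)  # (batch_size, seq_len, vocab_size) == (2, 3, vocab_size)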
