
Commit 9813f90

Remove position_ids generation in IPEXModel forward (#566)
* fix jit model
* rm autocast in model
* support assisted decoding and add reorder cache function
* add comment for _prepare_past_key_values
* rebase main
* fix model_dtype
* rm useless comments
* fix class name
* revert _call_model
* fix model_dtype warning log
* testing low precision ipex model
* add assisted decoding
* remove low-precision testing as CI node does not support bf16
* fix conflict
* remove prepare position_ids in forward
* fix code style
1 parent 45dab01 commit 9813f90


2 files changed: +18 -12 lines


optimum/intel/ipex/modeling_base.py (-6 lines)
@@ -506,12 +506,6 @@ def forward(
             "attention_mask": attention_mask,
         }
 
-        if "position_ids" in self.input_names and position_ids is None:
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-
         if "position_ids" in self.input_names or not self.input_names:
             inputs["position_ids"] = position_ids

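The deleted block is the standard fallback for deriving position_ids from the attention mask. After this change, forward no longer synthesizes them; callers are expected to supply position_ids when the traced graph needs them, for example via prepare_inputs_for_generation as the updated test below does. For reference, a minimal standalone sketch of that derivation (the helper name is hypothetical and not part of this commit):

import torch

def build_position_ids(attention_mask: torch.Tensor, has_past: bool) -> torch.Tensor:
    # Cumulative sum over the mask yields 0-based positions for real tokens.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Padding positions get a harmless placeholder value of 1 (they are masked out anyway).
    position_ids.masked_fill_(attention_mask == 0, 1)
    if has_past:
        # With a KV cache, only the newest token's position is needed.
        position_ids = position_ids[:, -1].unsqueeze(-1)
    return position_ids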
tests/ipex/test_modeling.py (+18, -6 lines)
@@ -32,7 +32,6 @@
     set_seed,
 )
 
-from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.intel import (
     IPEXModel,
     IPEXModelForAudioClassification,
@@ -236,11 +235,8 @@ def test_compare_to_transformers(self, model_arch):
             return_tensors="pt",
             return_token_type_ids=False if model_arch in ("llama", "llama2") else None,
         )
-        position_ids = None
-        if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
-            input_shape = tokens["input_ids"].shape
-            position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
-        outputs = ipex_model(**tokens, position_ids=position_ids)
+        inputs = ipex_model.prepare_inputs_for_generation(**tokens)
+        outputs = ipex_model(**inputs)
 
         self.assertIsInstance(outputs.logits, torch.Tensor)
         self.assertIsInstance(outputs.past_key_values, (tuple, list))
@@ -263,6 +259,22 @@ def test_pipeline(self, model_arch):
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_assisted_decoding(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokens = tokenizer("This is a sample input", return_tensors="pt")
+        ipex_output = ipex_model.generate(**tokens, do_sample=False)
+        ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model)
+        transformers_output = transformers_model.generate(**tokens, do_sample=False)
+        transformers_output_assisted = transformers_model.generate(
+            **tokens, do_sample=False, assistant_model=ipex_model
+        )
+        self.assertTrue(torch.equal(ipex_output, ipex_output_assisted))
+        self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))
+
     @parameterized.expand(
         grid_parameters(
             {