Skip to content

Commit 55a59e3

Browse files
committed
add assisted decoding
1 parent 3a966c5 commit 55a59e3

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/ipex/test_modeling.py

+11
Original file line number | Diff line number | Diff line change
@@ -275,6 +275,17 @@ def test_pipeline(self, model_arch):
275275
self.assertEqual(pipe.device, model.device)
276276
self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
277277

278+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_assisted_decoding(self, model_arch):
    # For each supported architecture, non-sampling generation
    # (do_sample=False) must produce the same token ids whether or not an
    # assistant model is supplied.
    checkpoint = MODEL_NAMES[model_arch]
    tok = AutoTokenizer.from_pretrained(checkpoint)
    ipex_model = IPEXModelForCausalLM.from_pretrained(checkpoint, export=True)
    # The assistant is the plain transformers model loaded from the same checkpoint.
    draft_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    encoded = tok("This is a sample input", return_tensors="pt")

    baseline_ids = ipex_model.generate(do_sample=False, **encoded)
    assisted_ids = ipex_model.generate(assistant_model=draft_model, do_sample=False, **encoded)

    outputs_match = torch.equal(baseline_ids, assisted_ids)
    self.assertTrue(outputs_match)
288+
278289
def test_compare_with_and_without_past_key_values(self):
279290
model_id = "echarlaix/tiny-random-gpt2-torchscript"
280291
tokenizer = AutoTokenizer.from_pretrained(model_id)

0 commit comments

Comments
 (0)