
Commit 8c60c7a

Update tests/ipex/test_modeling.py
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 070a0dc commit 8c60c7a

File tree: 1 file changed (+13 -52 lines)

tests/ipex/test_modeling.py (+13 -52)
@@ -277,65 +277,26 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
         set_seed(SEED)
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
         self.assertEqual(model.use_cache, use_cache)
+        trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
         # Test with batch_size is 1 and 2.
         texts = ["This is a sample", ["This is the first input", "This is the second input"]]
+        generation_configs = (
+            GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=True),
+            GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True),
+            GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=True),
+            GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=True),
+            GenerationConfig(max_new_tokens=4, do_sample=not use_cache, top_p=1.0, top_k=5, penalty_alpha=0.6),
+            GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0),
+        )
         for text in texts:
-            for num_beams in [2, 4, 8, 32]:
-                tokens = tokenizer(text, padding=True, return_tensors="pt")
-                generation_config = GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True)
+            tokens = tokenizer(text, padding=True, return_tensors="pt")
+            for generation_config in generation_configs:
                 outputs = model.generate(**tokens, generation_config=generation_config)
+                transformers_outputs = trasnformers_model.generate(**tokens, generation_config=generation_config)
                 self.assertIsInstance(outputs, torch.Tensor)
-
-    @parameterized.expand(
-        grid_parameters(
-            {
-                "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
-                "use_cache": [True, False],
-            }
-        )
-    )
-    @unittest.skipIf(is_ipex_version("<", "2.5.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
-    def test_ipex_patching_top_k(self, test_name, model_arch, use_cache):
-        model_id = MODEL_NAMES[model_arch]
-        set_seed(SEED)
-        model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
-        self.assertEqual(model.use_cache, use_cache)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
-        # Test with batch_size is 1 and 2.
-        texts = ["This is a sample", ["This is the first input", "This is the second input"]]
-        for text in texts:
-            tokens = tokenizer(text, padding=True, return_tensors="pt")
-            generation_config = GenerationConfig(max_new_tokens=4, do_sample=True, top_p=1.0, top_k=5)
-            outputs = model.generate(**tokens, generation_config=generation_config)
-            self.assertIsInstance(outputs, torch.Tensor)
-
-    @parameterized.expand(
-        grid_parameters(
-            {
-                "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
-                "use_cache": [True, False],
-            }
-        )
-    )
-    @unittest.skipIf(is_ipex_version("<", "2.5.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
-    def test_ipex_patching_top_p(self, test_name, model_arch, use_cache):
-        model_id = MODEL_NAMES[model_arch]
-        set_seed(SEED)
-        model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
-        self.assertEqual(model.use_cache, use_cache)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
-        # Test with batch_size is 1 and 2.
-        texts = ["This is a sample", ["This is the first input", "This is the second input"]]
-        for text in texts:
-            tokens = tokenizer(text, padding=True, return_tensors="pt")
-            generation_config = GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0)
-            outputs = model.generate(**tokens, generation_config=generation_config)
-            self.assertIsInstance(outputs, torch.Tensor)
-
+                self.assertEqual(outputs, transformers_outputs)
     def test_compare_with_and_without_past_key_values(self):
         model_id = "echarlaix/tiny-random-gpt2-torchscript"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
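The commit folds the separate top-k and top-p tests into the beam-search test by sweeping a tuple of GenerationConfig objects and comparing the IPEX-exported model's output against a plain transformers model. Below is a minimal standalone sketch of that consolidated check for readers who want to try it outside the test suite. The model id "hf-internal-testing/tiny-random-gpt2", the seed value, the per-call re-seeding, and the use of torch.equal are illustrative assumptions, not taken from the commit (the test itself uses MODEL_NAMES[model_arch], a single set_seed(SEED), and self.assertEqual).

```python
# Standalone sketch of the consolidated check (illustrative model id and seed;
# the test uses MODEL_NAMES[model_arch] and SEED instead).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, set_seed
from optimum.intel import IPEXModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # hypothetical tiny checkpoint
use_cache = True

ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Batch sizes 1 and 2, mirroring the test's `texts` list.
texts = ["This is a sample", ["This is the first input", "This is the second input"]]
generation_configs = (
    GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=True),
    GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True),
    GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0),
)

for text in texts:
    tokens = tokenizer(text, padding=True, return_tensors="pt")
    for generation_config in generation_configs:
        # Re-seed before each call so both models draw the same random numbers;
        # only then is an exact token-level comparison meaningful with do_sample=True.
        set_seed(42)
        outputs = ipex_model.generate(**tokens, generation_config=generation_config)
        set_seed(42)
        transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config)
        assert isinstance(outputs, torch.Tensor)
        # torch.equal is a stricter tensor comparison than unittest's assertEqual.
        assert torch.equal(outputs, transformers_outputs)
```

Re-seeding immediately before each generate call keeps the sampling RNG aligned between the patched and reference models, which is what makes an exact comparison of generated token ids a reasonable expectation when sampling is enabled.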
