     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
 )
+from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.utils.testing_utils import grid_parameters
 
 
 SEED = 42
@@ -215,6 +217,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
         "mpt",
         "opt",
     )
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama",)
     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.0
 
@@ -256,19 +259,36 @@ def test_pipeline(self, model_arch):
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
 
-    @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_multiple_inputs(self, model_arch):
+    @parameterized.expand(
+        grid_parameters(
+            {
+                "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
+                "use_cache": [True, False],
+                "num_beams": [1, 4],
+                "batch_size": [1, 4],
+            }
+        )
+    )
+    @unittest.skipIf(is_ipex_version("<=", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
+    def test_ipex_patching(self, test_name, model_arch, use_cache, num_beams, batch_size):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
-        texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"]
+        texts = ["This is a sample"] * batch_size
         tokens = tokenizer(texts, padding=True, return_tensors="pt")
-        generation_config = GenerationConfig(encoder_no_repeat_ngram_size=0, max_new_tokens=20, num_beams=2)
+        generation_config = GenerationConfig(
+            max_new_tokens=16, num_beams=num_beams, do_sample=False, use_cache=use_cache
+        )
         outputs = model.generate(**tokens, generation_config=generation_config)
-        self.assertIsInstance(outputs, torch.Tensor)
-        self.assertEqual(outputs.shape[0], 3)
+        # `generate` returns token ids, so compare logits from a fresh forward pass
+        logits = model(**tokens).logits
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**tokens)
+        self.assertIsInstance(logits, torch.Tensor)
+        self.assertTrue(torch.allclose(logits, transformers_outputs.logits, atol=1e-4))
 
     def test_compare_with_and_without_past_key_values(self):
         model_id = "echarlaix/tiny-random-gpt2-torchscript"
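
For reference, `grid_parameters` from `optimum.utils.testing_utils` expands a dict of option lists into one named case per combination, which is why the new test receives a `test_name` argument ahead of the grid values. A minimal sketch of the behavior the test relies on (the sketch's name and exact output format are assumptions, not the library's verbatim implementation):

```python
from itertools import product

def grid_parameters_sketch(grid):
    # Yield [test_name, value1, value2, ...] for every combination in the
    # grid, in the shape parameterized.expand unpacks into test arguments.
    for values in product(*grid.values()):
        yield ["_".join(str(v) for v in values), *values]
```

With the grid used above ("llama" x {True, False} x {1, 4} x {1, 4}) this yields 1 * 2 * 2 * 2 = 8 cases, e.g. `["llama_True_1_4", "llama", True, 1, 4]`.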