Commit e03259c

test ipex patching
1 parent 3a86c40 commit e03259c

1 file changed (+26, -6):

tests/ipex/test_modeling.py

@@ -43,6 +43,8 @@
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
 )
+from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.utils.testing_utils import grid_parameters
 
 
 SEED = 42
@@ -215,6 +217,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
         "mpt",
         "opt",
     )
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama",)
     GENERATION_LENGTH = 100
     SPEEDUP_CACHE = 1.0
 
@@ -256,19 +259,36 @@ def test_pipeline(self, model_arch):
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
 
-    @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_multiple_inputs(self, model_arch):
+    @parameterized.expand(
+        grid_parameters(
+            {
+                "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
+                "use_cache": [True, False],
+                "num_beams": [1, 4],
+                "batch_size": [1, 4],
+            }
+        )
+    )
+    @unittest.skipIf(is_ipex_version("<=", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
+    def test_ipex_patching(self, test_name, model_arch, use_cache, num_beams, batch_size):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
-        texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"]
+        texts = ["This is a sample"] * batch_size
         tokens = tokenizer(texts, padding=True, return_tensors="pt")
-        generation_config = GenerationConfig(encoder_no_repeat_ngram_size=0, max_new_tokens=20, num_beams=2)
+        generation_config = GenerationConfig(
+            max_new_tokens=16, num_beams=num_beams, do_sample=False, use_cache=use_cache
+        )
         outputs = model.generate(**tokens, generation_config=generation_config)
-        self.assertIsInstance(outputs, torch.Tensor)
-        self.assertEqual(outputs.shape[0], 3)
+        logits = model(**tokens).logits
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**tokens)
+        self.assertIsInstance(logits, torch.Tensor)
+        # Compare forward-pass logits against the reference transformers model
+        self.assertTrue(torch.allclose(logits, transformers_outputs.logits, atol=1e-4))
 
     def test_compare_with_and_without_past_key_values(self):
         model_id = "echarlaix/tiny-random-gpt2-torchscript"

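For readers unfamiliar with the grid_parameters helper imported in this diff: it expands a dictionary of parameter lists into one named test case per combination, which is why test_ipex_patching receives a test_name argument ahead of the grid values. Below is a minimal sketch of that expansion, assuming behavior roughly equivalent to optimum.utils.testing_utils.grid_parameters; the stand-in name expand_grid is hypothetical and only illustrates the shape of the generated cases.

from itertools import product


def expand_grid(parameters):
    """Hypothetical stand-in for optimum's grid_parameters: yield
    [test_name, *values] for every combination of the given lists."""
    keys = list(parameters)
    for values in product(*(parameters[key] for key in keys)):
        # parameterized.expand uses the first element to suffix the test name
        yield ["_".join(str(value) for value in values), *values]


# The grid used by test_ipex_patching above expands to 1 * 2 * 2 * 2 = 8 cases
grid = {
    "model_arch": ("llama",),
    "use_cache": [True, False],
    "num_beams": [1, 4],
    "batch_size": [1, 4],
}
for case in expand_grid(grid):
    print(case)  # e.g. ['llama_True_1_1', 'llama', True, 1, 1]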