Skip to content

Commit f1984f4

Browse files
authored
Add llama test model to cover MQA (#585)
* Change llama test model to cover MQA
* Keep llama and llama2 in tests
* Fix code style
1 parent 5e319aa commit f1984f4

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

tests/generation/test_modeling.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
3232
"mistral": "echarlaix/tiny-random-mistral",
3333
"llama": "fxmarty/tiny-llama-fast-tokenizer",
34+
"llama2": "Jiqing/tiny_random_llama2",
3435
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
3536
}
3637

@@ -54,6 +55,7 @@ class ModelingIntegrationTest(unittest.TestCase):
5455
"gpt_neo",
5556
"mistral",
5657
"llama",
58+
"llama2",
5759
# "gpt_bigcode",
5860
)
5961

tests/ipex/test_inference.py

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
4343
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
4444
"llama": "fxmarty/tiny-llama-fast-tokenizer",
45+
"llama2": "Jiqing/tiny_random_llama2",
4546
"opt": "hf-internal-testing/tiny-random-OPTModel",
4647
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
4748
}
@@ -66,6 +67,7 @@ class IPEXIntegrationTest(unittest.TestCase):
6667
"gpt_neo",
6768
# "gpt_bigcode",
6869
"llama",
70+
"llama2",
6971
"opt",
7072
"mpt",
7173
)

tests/ipex/test_modeling.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
6868
"levit": "hf-internal-testing/tiny-random-LevitModel",
6969
"llama": "fxmarty/tiny-llama-fast-tokenizer",
70+
"llama2": "Jiqing/tiny_random_llama2",
7071
"marian": "sshleifer/tiny-marian-en-de",
7172
"mbart": "hf-internal-testing/tiny-random-mbart",
7273
"mistral": "echarlaix/tiny-random-mistral",
@@ -209,6 +210,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
209210
"gpt_neo",
210211
"gpt_neox",
211212
"llama",
213+
"llama2",
212214
"mistral",
213215
# "phi",
214216
"mpt",
@@ -226,7 +228,9 @@ def test_compare_to_transformers(self, model_arch):
226228
self.assertTrue(ipex_model.use_cache)
227229
tokenizer = AutoTokenizer.from_pretrained(model_id)
228230
tokens = tokenizer(
229-
"This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
231+
"This is a sample",
232+
return_tensors="pt",
233+
return_token_type_ids=False if model_arch in ("llama", "llama2") else None,
230234
)
231235
position_ids = None
232236
if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:

0 commit comments

Comments (0)