Skip to content

Commit c743886

Browse files
authored
IPEX test refactorization (#711)
1 parent 02d5e4e commit c743886

File tree

4 files changed

+96
-137
lines changed

4 files changed

+96
-137
lines changed

tests/ipex/test_inference.py

+37-50
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
import torch
1818
from parameterized import parameterized
19-
20-
# TODO : add more tasks
2119
from transformers import (
2220
AutoModelForCausalLM,
2321
AutoModelForQuestionAnswering,
@@ -26,60 +24,51 @@
2624
AutoTokenizer,
2725
pipeline,
2826
)
27+
from utils_tests import MODEL_NAMES
2928

3029
from optimum.intel import inference_mode as ipex_inference_mode
3130
from optimum.intel.ipex.modeling_base import IPEXModel
3231

3332

34-
MODEL_NAMES = {
35-
"bert": "hf-internal-testing/tiny-random-bert",
36-
"bloom": "hf-internal-testing/tiny-random-BloomModel",
37-
"distilbert": "hf-internal-testing/tiny-random-distilbert",
38-
"roberta": "hf-internal-testing/tiny-random-roberta",
39-
"gptj": "hf-internal-testing/tiny-random-gptj",
40-
"gpt2": "hf-internal-testing/tiny-random-gpt2",
41-
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
42-
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
43-
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
44-
"llama": "fxmarty/tiny-llama-fast-tokenizer",
45-
"llama2": "Jiqing/tiny_random_llama2",
46-
"opt": "hf-internal-testing/tiny-random-OPTModel",
47-
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
48-
}
49-
5033
_CLASSIFICATION_TASK_TO_AUTOMODELS = {
5134
"text-classification": AutoModelForSequenceClassification,
5235
"token-classification": AutoModelForTokenClassification,
5336
}
5437

5538

56-
class IPEXIntegrationTest(unittest.TestCase):
57-
CLASSIFICATION_SUPPORTED_ARCHITECTURES = (
39+
class IPEXClassificationTest(unittest.TestCase):
40+
SUPPORTED_ARCHITECTURES = (
5841
"bert",
5942
"distilbert",
6043
"roberta",
6144
)
6245

63-
TEXT_GENERATION_SUPPORTED_ARCHITECTURES = (
64-
"bloom",
65-
"gptj",
66-
"gpt2",
67-
"gpt_neo",
68-
"gpt_bigcode",
69-
"llama",
70-
"llama2",
71-
"opt",
72-
"mpt",
73-
)
46+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
47+
def test_pipeline_inference(self, model_arch):
48+
model_id = MODEL_NAMES[model_arch]
49+
tokenizer = AutoTokenizer.from_pretrained(model_id)
50+
inputs = "This is a sample input"
51+
for task, auto_model_class in _CLASSIFICATION_TASK_TO_AUTOMODELS.items():
52+
model = auto_model_class.from_pretrained(model_id, torch_dtype=torch.float32)
53+
pipe = pipeline(task, model=model, tokenizer=tokenizer)
7454

75-
QA_SUPPORTED_ARCHITECTURES = (
55+
with torch.inference_mode():
56+
outputs = pipe(inputs)
57+
with ipex_inference_mode(pipe, dtype=model.config.torch_dtype, verbose=False, jit=True) as ipex_pipe:
58+
outputs_ipex = ipex_pipe(inputs)
59+
self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule))
60+
self.assertEqual(outputs[0]["score"], outputs_ipex[0]["score"])
61+
62+
63+
class IPEXQuestionAnsweringTest(unittest.TestCase):
64+
SUPPORTED_ARCHITECTURES = (
7665
"bert",
7766
"distilbert",
7867
"roberta",
7968
)
8069

81-
@parameterized.expand(QA_SUPPORTED_ARCHITECTURES)
82-
def test_question_answering_pipeline_inference(self, model_arch):
70+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
71+
def test_pipeline_inference(self, model_arch):
8372
model_id = MODEL_NAMES[model_arch]
8473
tokenizer = AutoTokenizer.from_pretrained(model_id)
8574
model = AutoModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=torch.float32)
@@ -95,24 +84,22 @@ def test_question_answering_pipeline_inference(self, model_arch):
9584
self.assertEqual(outputs["start"], outputs_ipex["start"])
9685
self.assertEqual(outputs["end"], outputs_ipex["end"])
9786

98-
@parameterized.expand(CLASSIFICATION_SUPPORTED_ARCHITECTURES)
99-
def test_classification_pipeline_inference(self, model_arch):
100-
model_id = MODEL_NAMES[model_arch]
101-
tokenizer = AutoTokenizer.from_pretrained(model_id)
102-
inputs = "This is a sample input"
103-
for task, auto_model_class in _CLASSIFICATION_TASK_TO_AUTOMODELS.items():
104-
model = auto_model_class.from_pretrained(model_id, torch_dtype=torch.float32)
105-
pipe = pipeline(task, model=model, tokenizer=tokenizer)
10687

107-
with torch.inference_mode():
108-
outputs = pipe(inputs)
109-
with ipex_inference_mode(pipe, dtype=model.config.torch_dtype, verbose=False, jit=True) as ipex_pipe:
110-
outputs_ipex = ipex_pipe(inputs)
111-
self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule))
112-
self.assertEqual(outputs[0]["score"], outputs_ipex[0]["score"])
88+
class IPEXTextGenerationTest(unittest.TestCase):
89+
SUPPORTED_ARCHITECTURES = (
90+
"bloom",
91+
"gptj",
92+
"gpt2",
93+
"gpt_neo",
94+
"gpt_bigcode",
95+
"llama",
96+
"llama2",
97+
"opt",
98+
"mpt",
99+
)
113100

114-
@parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES)
115-
def test_text_generation_pipeline_inference(self, model_arch):
101+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
102+
def test_pipeline_inference(self, model_arch):
116103
model_id = MODEL_NAMES[model_arch]
117104
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, return_dict=False)
118105
model = model.eval()

tests/ipex/test_modeling.py

+1-43
Original file line numberDiff line numberDiff line change
@@ -45,53 +45,11 @@
4545
)
4646
from optimum.intel.utils.import_utils import is_ipex_version
4747
from optimum.utils.testing_utils import grid_parameters
48+
from utils_tests import MODEL_NAMES
4849

4950

5051
SEED = 42
5152

52-
MODEL_NAMES = {
53-
"albert": "hf-internal-testing/tiny-random-albert",
54-
"beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
55-
"bert": "hf-internal-testing/tiny-random-bert",
56-
"bart": "hf-internal-testing/tiny-random-bart",
57-
"blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
58-
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
59-
"bloom": "hf-internal-testing/tiny-random-BloomModel",
60-
"convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
61-
"codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
62-
"convnext": "hf-internal-testing/tiny-random-convnext",
63-
"distilbert": "hf-internal-testing/tiny-random-distilbert",
64-
"electra": "hf-internal-testing/tiny-random-electra",
65-
"flaubert": "hf-internal-testing/tiny-random-flaubert",
66-
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
67-
"gpt2": "hf-internal-testing/tiny-random-gpt2",
68-
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
69-
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
70-
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
71-
"levit": "hf-internal-testing/tiny-random-LevitModel",
72-
"llama": "fxmarty/tiny-llama-fast-tokenizer",
73-
"llama2": "Jiqing/tiny_random_llama2",
74-
"marian": "sshleifer/tiny-marian-en-de",
75-
"mbart": "hf-internal-testing/tiny-random-mbart",
76-
"mistral": "echarlaix/tiny-random-mistral",
77-
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
78-
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
79-
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
80-
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
81-
"mt5": "stas/mt5-tiny-random",
82-
"opt": "hf-internal-testing/tiny-random-OPTModel",
83-
"phi": "echarlaix/tiny-random-PhiForCausalLM",
84-
"resnet": "hf-internal-testing/tiny-random-resnet",
85-
"roberta": "hf-internal-testing/tiny-random-roberta",
86-
"roformer": "hf-internal-testing/tiny-random-roformer",
87-
"squeezebert": "hf-internal-testing/tiny-random-squeezebert",
88-
"t5": "hf-internal-testing/tiny-random-t5",
89-
"unispeech": "hf-internal-testing/tiny-random-unispeech",
90-
"vit": "hf-internal-testing/tiny-random-vit",
91-
"wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
92-
"xlm": "hf-internal-testing/tiny-random-xlm",
93-
}
94-
9553

9654
class Timer(object):
9755
def __enter__(self):

tests/ipex/test_pipelines.py

+1-44
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from parameterized import parameterized
2121
from transformers import AutoTokenizer
2222
from transformers.pipelines import pipeline as transformers_pipeline
23+
from utils_tests import MODEL_NAMES
2324

2425
from optimum.intel.ipex.modeling_base import (
2526
IPEXModelForAudioClassification,
@@ -33,50 +34,6 @@
3334
from optimum.intel.pipelines import pipeline as ipex_pipeline
3435

3536

36-
MODEL_NAMES = {
37-
"albert": "hf-internal-testing/tiny-random-albert",
38-
"beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
39-
"bert": "hf-internal-testing/tiny-random-bert",
40-
"bart": "hf-internal-testing/tiny-random-bart",
41-
"blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
42-
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
43-
"bloom": "hf-internal-testing/tiny-random-BloomModel",
44-
"convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
45-
"codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
46-
"convnext": "hf-internal-testing/tiny-random-convnext",
47-
"distilbert": "hf-internal-testing/tiny-random-distilbert",
48-
"electra": "hf-internal-testing/tiny-random-electra",
49-
"flaubert": "hf-internal-testing/tiny-random-flaubert",
50-
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
51-
"gpt2": "hf-internal-testing/tiny-random-gpt2",
52-
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
53-
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
54-
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
55-
"levit": "hf-internal-testing/tiny-random-LevitModel",
56-
"llama": "fxmarty/tiny-llama-fast-tokenizer",
57-
"llama2": "Jiqing/tiny_random_llama2",
58-
"marian": "sshleifer/tiny-marian-en-de",
59-
"mbart": "hf-internal-testing/tiny-random-mbart",
60-
"mistral": "echarlaix/tiny-random-mistral",
61-
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
62-
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
63-
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
64-
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
65-
"mt5": "stas/mt5-tiny-random",
66-
"opt": "hf-internal-testing/tiny-random-OPTModel",
67-
"phi": "echarlaix/tiny-random-PhiForCausalLM",
68-
"resnet": "hf-internal-testing/tiny-random-resnet",
69-
"roberta": "hf-internal-testing/tiny-random-roberta",
70-
"roformer": "hf-internal-testing/tiny-random-roformer",
71-
"squeezebert": "hf-internal-testing/tiny-random-squeezebert",
72-
"t5": "hf-internal-testing/tiny-random-t5",
73-
"unispeech": "hf-internal-testing/tiny-random-unispeech",
74-
"vit": "hf-internal-testing/tiny-random-vit",
75-
"wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
76-
"xlm": "hf-internal-testing/tiny-random-xlm",
77-
}
78-
79-
8037
class PipelinesIntegrationTest(unittest.TestCase):
8138
COMMON_SUPPORTED_ARCHITECTURES = (
8239
"albert",

tests/ipex/utils_tests.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
# Shared registry of tiny test checkpoints used across the IPEX test suite.
#
# Maps a model architecture identifier (as used by the test classes'
# SUPPORTED_ARCHITECTURES tuples) to a Hugging Face Hub repo id of a
# tiny/random checkpoint, so the tests download minimal weights.
MODEL_NAMES = {
    "albert": "hf-internal-testing/tiny-random-albert",
    "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
    "bert": "hf-internal-testing/tiny-random-bert",
    "bart": "hf-internal-testing/tiny-random-bart",
    "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
    "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
    "bloom": "hf-internal-testing/tiny-random-BloomModel",
    "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
    "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
    "convnext": "hf-internal-testing/tiny-random-convnext",
    "distilbert": "hf-internal-testing/tiny-random-distilbert",
    "electra": "hf-internal-testing/tiny-random-electra",
    "flaubert": "hf-internal-testing/tiny-random-flaubert",
    "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
    "gpt2": "hf-internal-testing/tiny-random-gpt2",
    "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
    "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
    "gptj": "hf-internal-testing/tiny-random-GPTJModel",
    "levit": "hf-internal-testing/tiny-random-LevitModel",
    "llama": "fxmarty/tiny-llama-fast-tokenizer",
    "llama2": "Jiqing/tiny_random_llama2",
    "marian": "sshleifer/tiny-marian-en-de",
    "mbart": "hf-internal-testing/tiny-random-mbart",
    "mistral": "echarlaix/tiny-random-mistral",
    "mobilenet_v1": "google/mobilenet_v1_0.75_192",
    "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
    "mobilevit": "hf-internal-testing/tiny-random-mobilevit",
    "mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
    "mt5": "stas/mt5-tiny-random",
    "opt": "hf-internal-testing/tiny-random-OPTModel",
    "phi": "echarlaix/tiny-random-PhiForCausalLM",
    "resnet": "hf-internal-testing/tiny-random-resnet",
    "roberta": "hf-internal-testing/tiny-random-roberta",
    "roformer": "hf-internal-testing/tiny-random-roformer",
    "squeezebert": "hf-internal-testing/tiny-random-squeezebert",
    "t5": "hf-internal-testing/tiny-random-t5",
    "unispeech": "hf-internal-testing/tiny-random-unispeech",
    "vit": "hf-internal-testing/tiny-random-vit",
    "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
    "xlm": "hf-internal-testing/tiny-random-xlm",
}

0 commit comments

Comments
 (0)