Skip to content

Commit c17b879

Browse files
authored
add IPEX tests (#267)
1 parent 67f24c1 commit c17b879

File tree

3 files changed

+112
-1
lines changed

3 files changed

+112
-1
lines changed

.github/workflows/test_ipex.yml

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3+
name: Intel IPEX - Test
4+
5+
on:
6+
push:
7+
branches: [ main ]
8+
pull_request:
9+
branches: [ main ]
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
13+
cancel-in-progress: true
14+
15+
jobs:
16+
build:
17+
strategy:
18+
fail-fast: false
19+
matrix:
20+
python-version: [3.8, 3.9]
21+
os: [ubuntu-latest]
22+
23+
runs-on: ${{ matrix.os }}
24+
steps:
25+
- uses: actions/checkout@v2
26+
- name: Setup Python ${{ matrix.python-version }}
27+
uses: actions/setup-python@v2
28+
with:
29+
python-version: ${{ matrix.python-version }}
30+
- name: Install dependencies
31+
run: |
32+
python -m pip install --upgrade pip
33+
pip install .[ipex,tests]
34+
pip install torch==1.13.0 intel-extension-for-pytorch==1.13.0
35+
- name: Test with Pytest
36+
run: |
37+
pytest tests/ipex/

setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"optimum>=1.7.3",
1616
"transformers>=4.20.0",
1717
"datasets>=1.4.0",
18-
"torch",
1918
"sentencepiece",
2019
"scipy",
2120
]

tests/ipex/test_inference.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2023 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import tempfile
17+
import unittest
18+
19+
import torch
20+
21+
# TODO : add more tasks
22+
from transformers import (
23+
AutoFeatureExtractor,
24+
AutoModel,
25+
AutoModelForAudioClassification,
26+
AutoModelForCausalLM,
27+
AutoModelForImageClassification,
28+
AutoModelForMaskedLM,
29+
AutoModelForQuestionAnswering,
30+
AutoModelForSeq2SeqLM,
31+
AutoModelForSequenceClassification,
32+
AutoModelForTokenClassification,
33+
AutoTokenizer,
34+
pipeline,
35+
)
36+
37+
from evaluate import evaluator
38+
from optimum.intel import inference_mode as ipex_inference_mode
39+
from parameterized import parameterized
40+
41+
42+
MODEL_NAMES = {
43+
"bert": "hf-internal-testing/tiny-random-bert",
44+
"distilbert": "hf-internal-testing/tiny-random-distilbert",
45+
"roberta": "hf-internal-testing/tiny-random-roberta",
46+
}
47+
48+
_TASK_TO_AUTOMODELS = {
49+
"text-classification": AutoModelForSequenceClassification,
50+
"token-classification": AutoModelForTokenClassification,
51+
}
52+
53+
54+
class IPEXIntegrationTest(unittest.TestCase):
55+
SUPPORTED_ARCHITECTURES = (
56+
"bert",
57+
"distilbert",
58+
"roberta",
59+
)
60+
61+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
62+
def test_pipeline_classification_inference(self, model_arch):
63+
model_id = MODEL_NAMES[model_arch]
64+
tokenizer = AutoTokenizer.from_pretrained(model_id)
65+
inputs = "This is a sample input"
66+
for task, auto_model_class in _TASK_TO_AUTOMODELS.items():
67+
model = auto_model_class.from_pretrained(model_id)
68+
pipe = pipeline(task, model=model, tokenizer=tokenizer)
69+
70+
with torch.inference_mode():
71+
outputs = pipe(inputs)
72+
with ipex_inference_mode(pipe) as ipex_pipe:
73+
outputs_ipex = ipex_pipe(inputs)
74+
75+
self.assertEqual(outputs[0]["score"], outputs_ipex[0]["score"])

0 commit comments

Comments
 (0)