Skip to content

Commit 8e9617f

Browse files
committed
Add task
1 parent 777db05 commit 8e9617f

File tree

6 files changed

+66
-48
lines changed

6 files changed

+66
-48
lines changed

optimum/intel/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@
4848
"IPEXModelForMaskedLM",
4949
"IPEXModelForTokenClassification",
5050
"IPEXModelForQuestionAnswering",
51+
"IPEXModelForImageClassification",
52+
"IPEXModelForAudioClassification",
5153
"IPEXModel",
5254
]
5355

54-
5556
try:
5657
if not (is_openvino_available() and is_nncf_available()):
5758
raise OptionalDependencyNotAvailable()
@@ -161,7 +162,9 @@
161162
else:
162163
from .ipex import (
163164
IPEXModel,
165+
IPEXModelForAudioClassification,
164166
IPEXModelForCausalLM,
167+
IPEXModelForImageClassification,
165168
IPEXModelForMaskedLM,
166169
IPEXModelForQuestionAnswering,
167170
IPEXModelForSequenceClassification,

optimum/intel/generation/modeling.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,12 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals
6666

6767
def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
6868
model_inputs = prepare_jit_inputs(model, task, use_cache)
69-
model.config.return_dict = False
69+
model.config.return_dict = "text-generation" not in task
70+
7071
# check if the model_inputs is correct.
7172
model(**model_inputs)
7273

7374
torch._C._jit_set_texpr_fuser_enabled(False)
74-
if "past_key_values" in model_inputs.keys():
75-
model.config.return_dict = False
7675
if is_torch_version(">=", "2.1.0"):
7776
traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
7877
else:

optimum/intel/ipex/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from optimum.intel.ipex.modeling_base import (
22
IPEXModel,
3+
IPEXModelForAudioClassification,
34
IPEXModelForCausalLM,
5+
IPEXModelForImageClassification,
46
IPEXModelForMaskedLM,
57
IPEXModelForQuestionAnswering,
68
IPEXModelForSequenceClassification,

optimum/intel/ipex/modeling_base.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from transformers import (
2626
AutoConfig,
2727
AutoModel,
28+
AutoModelForAudioClassification,
2829
AutoModelForCausalLM,
30+
AutoModelForImageClassification,
2931
AutoModelForMaskedLM,
3032
AutoModelForQuestionAnswering,
3133
AutoModelForSequenceClassification,
@@ -196,14 +198,24 @@ class IPEXModelForSequenceClassification(IPEXModel):
196198
export_feature = "text-classification"
197199

198200

201+
class IPEXModelForTokenClassification(IPEXModel):
202+
auto_model_class = AutoModelForTokenClassification
203+
export_feature = "token-classification"
204+
205+
199206
class IPEXModelForMaskedLM(IPEXModel):
200207
auto_model_class = AutoModelForMaskedLM
201208
export_feature = "fill-mask"
202209

203210

204-
class IPEXModelForTokenClassification(IPEXModel):
205-
auto_model_class = AutoModelForTokenClassification
206-
export_feature = "token-classification"
211+
class IPEXModelForImageClassification(IPEXModel):
212+
auto_model_class = AutoModelForImageClassification
213+
export_feature = "image-classification"
214+
215+
216+
class IPEXModelForAudioClassification(IPEXModel):
217+
auto_model_class = AutoModelForAudioClassification
218+
export_feature = "audio-classification"
207219

208220

209221
class IPEXModelForQuestionAnswering(IPEXModel):

optimum/intel/utils/dummy_ipex_objects.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def __init__(self, *args, **kwargs):
2222
requires_backends(self, ["ipex"])
2323

2424

25+
class IPEXModel(metaclass=DummyObject):
26+
_backends = ["ipex"]
27+
28+
def __init__(self, *args, **kwargs):
29+
requires_backends(self, ["ipex"])
30+
31+
@classmethod
32+
def from_pretrained(cls, *args, **kwargs):
33+
requires_backends(cls, ["ipex"])
34+
35+
2536
class IPEXModelForSequenceClassification(metaclass=DummyObject):
2637
_backends = ["ipex"]
2738

@@ -77,7 +88,29 @@ def from_pretrained(cls, *args, **kwargs):
7788
requires_backends(cls, ["ipex"])
7889

7990

80-
class IPEXModel(metaclass=DummyObject):
91+
class IPEXModelForMaskedLM(metaclass=DummyObject):
92+
_backends = ["ipex"]
93+
94+
def __init__(self, *args, **kwargs):
95+
requires_backends(self, ["ipex"])
96+
97+
@classmethod
98+
def from_pretrained(cls, *args, **kwargs):
99+
requires_backends(cls, ["ipex"])
100+
101+
102+
class IPEXModelForImageClassification(metaclass=DummyObject):
103+
_backends = ["ipex"]
104+
105+
def __init__(self, *args, **kwargs):
106+
requires_backends(self, ["ipex"])
107+
108+
@classmethod
109+
def from_pretrained(cls, *args, **kwargs):
110+
requires_backends(cls, ["ipex"])
111+
112+
113+
class IPEXModelForAudioClassification(metaclass=DummyObject):
81114
_backends = ["ipex"]
82115

83116
def __init__(self, *args, **kwargs):

tests/ipex/test_modeling.py

+9-40
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
3131
from optimum.intel import (
3232
IPEXModel,
33+
IPEXModelForAudioClassification,
3334
IPEXModelForCausalLM,
35+
IPEXModelForImageClassification,
36+
IPEXModelForMaskedLM,
3437
IPEXModelForQuestionAnswering,
3538
IPEXModelForSequenceClassification,
3639
IPEXModelForTokenClassification,
@@ -110,7 +113,9 @@ def test_compare_to_transformers(self, model_arch):
110113
transformers_outputs = transformers_model(**tokens)
111114
outputs = ipex_model(**tokens)
112115
# Compare tensor outputs
113-
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
116+
for output_name in {"logits", "last_hidden_state"}:
117+
if output_name in transformers_outputs:
118+
self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4))
114119

115120
@parameterized.expand(SUPPORTED_ARCHITECTURES)
116121
def test_pipeline(self, model_arch):
@@ -119,7 +124,7 @@ def test_pipeline(self, model_arch):
119124
tokenizer = AutoTokenizer.from_pretrained(model_id)
120125
pipe = pipeline(self.IPEX_MODEL_CLASS.export_feature, model=model, tokenizer=tokenizer)
121126
text = "This restaurant is awesome"
122-
outputs = pipe(text)
127+
_ = pipe(text)
123128

124129
self.assertEqual(pipe.device, model.device)
125130

@@ -132,44 +137,8 @@ class IPEXModelForTokenClassificationTest(IPEXModelTest):
132137
IPEX_MODEL_CLASS = IPEXModelForSequenceClassification
133138

134139

135-
class IPEXModelForQuestionAnsweringTest(unittest.TestCase):
136-
SUPPORTED_ARCHITECTURES = (
137-
"bert",
138-
"distilbert",
139-
"roberta",
140-
)
141-
142-
@parameterized.expand(SUPPORTED_ARCHITECTURES)
143-
def test_compare_to_transformers(self, model_arch):
144-
model_id = MODEL_NAMES[model_arch]
145-
set_seed(SEED)
146-
ipex_model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)
147-
self.assertIsInstance(ipex_model.config, PretrainedConfig)
148-
transformers_model = AutoModelForQuestionAnswering.from_pretrained(model_id)
149-
tokenizer = AutoTokenizer.from_pretrained(model_id)
150-
inputs = "This is a sample input"
151-
tokens = tokenizer(inputs, return_tensors="pt")
152-
with torch.no_grad():
153-
transformers_outputs = transformers_model(**tokens)
154-
outputs = ipex_model(**tokens)
155-
self.assertIn("start_logits", outputs)
156-
self.assertIn("end_logits", outputs)
157-
# Compare tensor outputs
158-
self.assertTrue(torch.allclose(outputs.start_logits, transformers_outputs.start_logits, atol=1e-4))
159-
self.assertTrue(torch.allclose(outputs.end_logits, transformers_outputs.end_logits, atol=1e-4))
160-
161-
@parameterized.expand(SUPPORTED_ARCHITECTURES)
162-
def test_pipeline(self, model_arch):
163-
model_id = MODEL_NAMES[model_arch]
164-
model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)
165-
tokenizer = AutoTokenizer.from_pretrained(model_id)
166-
pipe = pipeline("question-answering", model=model, tokenizer=tokenizer)
167-
question = "What's my name?"
168-
context = "My Name is Sasha and I live in Lyon."
169-
outputs = pipe(question, context)
170-
self.assertEqual(pipe.device, model.device)
171-
self.assertGreaterEqual(outputs["score"], 0.0)
172-
self.assertIsInstance(outputs["answer"], str)
140+
class IPEXModelForMaskedLMTest(IPEXModelTest):
141+
IPEX_MODEL_CLASS = IPEXModelForMaskedLM
173142

174143

175144
class IPEXModelForQuestionAnsweringTest(unittest.TestCase):

0 commit comments

Comments
 (0)