
Commit 6e79be1

Add IPEX models for audio and image classification tasks (#536)
* add test
* format
* Add image classification task
* Add test
1 parent 6bf5fbc commit 6e79be1

File tree

6 files changed: 235 additions, 24 deletions

optimum/intel/__init__.py (+6, -1)

@@ -48,9 +48,11 @@
         "IPEXModelForMaskedLM",
         "IPEXModelForTokenClassification",
         "IPEXModelForQuestionAnswering",
+        "IPEXModelForImageClassification",
+        "IPEXModelForAudioClassification",
+        "IPEXModel",
     ]
 
-
 try:
     if not (is_openvino_available() and is_nncf_available()):
         raise OptionalDependencyNotAvailable()
@@ -159,7 +161,10 @@
         from .utils.dummy_ipex_objects import *
     else:
         from .ipex import (
+            IPEXModel,
+            IPEXModelForAudioClassification,
             IPEXModelForCausalLM,
+            IPEXModelForImageClassification,
             IPEXModelForMaskedLM,
             IPEXModelForQuestionAnswering,
             IPEXModelForSequenceClassification,

optimum/intel/generation/modeling.py (+1, -3)

@@ -66,13 +66,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):
 
 
 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
-    model.config.return_dict = False
+    model.config.return_dict = task not in {"text-generation", "audio-classification"}
     # check if the model_inputs is correct.
     model(**model_inputs)
 
     torch._C._jit_set_texpr_fuser_enabled(False)
-    if "past_key_values" in model_inputs.keys():
-        model.config.return_dict = False
     if is_torch_version(">=", "2.1.0"):
         traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
     else:
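
The change above sets return_dict per task instead of unconditionally forcing it to False: "text-generation" and "audio-classification" are traced from tuple outputs, while the remaining tasks keep dict-like outputs, and the IPEXModel.forward in this same PR handles both shapes. A minimal sketch, not part of this diff and using an assumed example checkpoint, of what the flag changes on a plain transformers model:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed example checkpoint, only for illustration.
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
inputs = tokenizer("IPEX makes this faster", return_tensors="pt")

model.config.return_dict = True
with torch.no_grad():
    print(type(model(**inputs)))  # SequenceClassifierOutput, a dict-like ModelOutput

model.config.return_dict = False
with torch.no_grad():
    print(type(model(**inputs)))  # plain tuple, logits at index 0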

optimum/intel/ipex/__init__.py (+3)

@@ -1,5 +1,8 @@
 from optimum.intel.ipex.modeling_base import (
+    IPEXModel,
+    IPEXModelForAudioClassification,
     IPEXModelForCausalLM,
+    IPEXModelForImageClassification,
     IPEXModelForMaskedLM,
     IPEXModelForQuestionAnswering,
     IPEXModelForSequenceClassification,

optimum/intel/ipex/modeling_base.py (+62, -8)

@@ -25,7 +25,9 @@
 from transformers import (
     AutoConfig,
     AutoModel,
+    AutoModelForAudioClassification,
     AutoModelForCausalLM,
+    AutoModelForImageClassification,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
@@ -68,6 +70,9 @@ def __init__(
         self.model.to(self._device)
         self.model_save_dir = model_save_dir
 
+        self.input_names = {
+            inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
+        }
         # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
         AutoConfig.register(self.base_model_prefix, AutoConfig)
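
The input_names set added above is read straight off the TorchScript graph of the loaded model; this is how the wrapper later knows whether the traced forward expects token_type_ids, attention_mask or past_key_values. A minimal sketch, not from this repository, of the same graph inspection on a toy traced module:

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Hypothetical stand-in for a traced transformer forward.
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        return (input_ids * attention_mask,)


example = (torch.ones(1, 8, dtype=torch.long), torch.ones(1, 8, dtype=torch.long))
traced = torch.jit.trace(TinyModel(), example)

# Same comprehension as in the diff: strip the ".N" suffix from each debug name
# and drop the module input itself.
input_names = {
    i.debugName().split(".")[0] for i in traced.graph.inputs() if i.debugName() != "self"
}
print(input_names)  # {"input_ids", "attention_mask"}; a stray "self" can appear on
                    # torch versions that name the module input "self.1"
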
@@ -170,8 +175,22 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)
         torch.jit.save(self.model, output_path)
 
-    def forward(self, *args, **kwargs):
-        outputs = self.model(*args, **kwargs)
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        if "token_type_ids" in self.input_names:
+            inputs["token_type_ids"] = token_type_ids
+
+        outputs = self.model(**inputs)
         return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
 
     def eval(self):
@@ -196,14 +215,52 @@ class IPEXModelForSequenceClassification(IPEXModel):
     export_feature = "text-classification"
 
 
+class IPEXModelForTokenClassification(IPEXModel):
+    auto_model_class = AutoModelForTokenClassification
+    export_feature = "token-classification"
+
+
 class IPEXModelForMaskedLM(IPEXModel):
     auto_model_class = AutoModelForMaskedLM
     export_feature = "fill-mask"
 
 
-class IPEXModelForTokenClassification(IPEXModel):
-    auto_model_class = AutoModelForTokenClassification
-    export_feature = "token-classification"
+class IPEXModelForImageClassification(IPEXModel):
+    auto_model_class = AutoModelForImageClassification
+    export_feature = "image-classification"
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        **kwargs,
+    ):
+        inputs = {
+            "pixel_values": pixel_values,
+        }
+
+        outputs = self.model(**inputs)
+        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
+
+
+class IPEXModelForAudioClassification(IPEXModel):
+    auto_model_class = AutoModelForAudioClassification
+    export_feature = "audio-classification"
+
+    def forward(
+        self,
+        input_values: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_values": input_values,
+        }
+
+        if "attention_mask" in self.input_names:
+            inputs["attention_mask"] = attention_mask
+
+        outputs = self.model(**inputs)
+        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
 
 
 class IPEXModelForQuestionAnswering(IPEXModel):
@@ -233,9 +290,6 @@ def __init__(
 
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", None)
-        self.input_names = {
-            inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
-        }
         self.use_cache = "past_key_values" in self.input_names
 
         if use_cache ^ self.use_cache:
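
With the two classes above exported from optimum.intel, they can be used like the other IPEXModelForXXX wrappers, including inside a transformers pipeline (the AutoModel registration in __init__ exists precisely to avoid pipeline warnings about the model class). A usage sketch, with hypothetical example checkpoints and assuming the export=True convention already used by the existing IPEX classes:

from transformers import AutoFeatureExtractor, AutoImageProcessor, pipeline

from optimum.intel import IPEXModelForAudioClassification, IPEXModelForImageClassification

# Image classification: the traced forward only needs pixel_values.
image_model_id = "google/vit-base-patch16-224"  # example checkpoint
image_model = IPEXModelForImageClassification.from_pretrained(image_model_id, export=True)
image_processor = AutoImageProcessor.from_pretrained(image_model_id)
image_classifier = pipeline("image-classification", model=image_model, image_processor=image_processor)

# Audio classification: attention_mask is only forwarded when the traced graph expects it.
audio_model_id = "superb/hubert-base-superb-ks"  # example checkpoint
audio_model = IPEXModelForAudioClassification.from_pretrained(audio_model_id, export=True)
feature_extractor = AutoFeatureExtractor.from_pretrained(audio_model_id)
audio_classifier = pipeline("audio-classification", model=audio_model, feature_extractor=feature_extractor)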

optimum/intel/utils/dummy_ipex_objects.py (+33)

@@ -22,6 +22,17 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["ipex"])
 
 
+class IPEXModel(metaclass=DummyObject):
+    _backends = ["ipex"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["ipex"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["ipex"])
+
+
 class IPEXModelForSequenceClassification(metaclass=DummyObject):
     _backends = ["ipex"]
 
@@ -75,3 +86,25 @@ def __init__(self, *args, **kwargs):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["ipex"])
+
+
+class IPEXModelForImageClassification(metaclass=DummyObject):
+    _backends = ["ipex"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["ipex"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["ipex"])
+
+
+class IPEXModelForAudioClassification(metaclass=DummyObject):
+    _backends = ["ipex"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["ipex"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["ipex"])
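
These dummies keep the new names importable from optimum.intel even when the IPEX backend is unavailable: the import itself never fails, and requires_backends only raises once the class is instantiated or from_pretrained is called. A minimal, self-contained sketch of the pattern, with simplified stand-ins rather than the real helpers from optimum:

def _requires_backends(obj, backends):
    # Simplified stand-in for requires_backends: pretend the backend is never installed.
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backend(s): {', '.join(backends)}")


class _DummyObject(type):
    """Simplified stand-in for the DummyObject metaclass used above."""

    def __getattr__(cls, key):
        _requires_backends(cls, cls._backends)


class IPEXModelForAudioClassification(metaclass=_DummyObject):
    _backends = ["ipex"]

    def __init__(self, *args, **kwargs):
        _requires_backends(self, ["ipex"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        _requires_backends(cls, ["ipex"])


# Defining (and importing) the dummy succeeds; only using it raises.
try:
    IPEXModelForAudioClassification.from_pretrained("some-org/some-model")
except ImportError as err:
    print(err)  # IPEXModelForAudioClassification requires the following backend(s): ipex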
