@@ -25,7 +25,9 @@
 from transformers import (
     AutoConfig,
     AutoModel,
+    AutoModelForAudioClassification,
     AutoModelForCausalLM,
+    AutoModelForImageClassification,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
@@ -68,6 +70,9 @@ def __init__(
         self.model.to(self._device)
         self.model_save_dir = model_save_dir
 
+        self.input_names = {
+            inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
+        }
         # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
         AutoConfig.register(self.base_model_prefix, AutoConfig)
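
For context on the `input_names` set added above: a `torch.jit.trace`d module's graph lists the module itself as a first input named `self`, and tracing can suffix argument names with a numeric part such as `input_ids.1`, which is why the comprehension filters out `"self"` and splits on `"."`. A minimal sketch of the behavior (the toy module and names are illustrative, not from this PR):

```python
import torch
from torch import nn

class Toy(nn.Module):
    def forward(self, input_ids, attention_mask):
        return input_ids + attention_mask

traced = torch.jit.trace(Toy(), (torch.ones(1, 4, dtype=torch.long),) * 2)
# graph.inputs() yields "self" first, then the traced tensor arguments,
# whose debug names may carry a ".1"-style suffix added by tracing
names = {i.debugName().split(".")[0] for i in traced.graph.inputs() if i.debugName() != "self"}
print(names)  # e.g. {"input_ids", "attention_mask"}
```
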
@@ -170,8 +175,22 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         output_path = os.path.join(save_directory, WEIGHTS_NAME)
         torch.jit.save(self.model, output_path)
 
-    def forward(self, *args, **kwargs):
-        outputs = self.model(*args, **kwargs)
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+
+        if "token_type_ids" in self.input_names:
+            inputs["token_type_ids"] = token_type_ids
+
+        outputs = self.model(**inputs)
         return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
 
     def eval(self):
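
The `token_type_ids` gate above matters because a traced graph only accepts the inputs it was traced with; DistilBERT-style models declare no `token_type_ids`, so forwarding one unconditionally would fail. A hypothetical calling sketch (`model` stands in for any loaded `IPEXModel` instance):

```python
import torch

batch = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
}
# Only supply token_type_ids when the traced graph actually declares it,
# mirroring the check inside IPEXModel.forward
if "token_type_ids" in model.input_names:
    batch["token_type_ids"] = torch.zeros(1, 8, dtype=torch.long)
logits = model(**batch).logits
```
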
@@ -196,14 +215,52 @@ class IPEXModelForSequenceClassification(IPEXModel):
     export_feature = "text-classification"
 
 
+class IPEXModelForTokenClassification(IPEXModel):
+    auto_model_class = AutoModelForTokenClassification
+    export_feature = "token-classification"
+
+
 class IPEXModelForMaskedLM(IPEXModel):
     auto_model_class = AutoModelForMaskedLM
     export_feature = "fill-mask"
 
 
-class IPEXModelForTokenClassification(IPEXModel):
-    auto_model_class = AutoModelForTokenClassification
-    export_feature = "token-classification"
+class IPEXModelForImageClassification(IPEXModel):
+    auto_model_class = AutoModelForImageClassification
+    export_feature = "image-classification"
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        **kwargs,
+    ):
+        inputs = {
+            "pixel_values": pixel_values,
+        }
+
+        outputs = self.model(**inputs)
+        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
+
+
+class IPEXModelForAudioClassification(IPEXModel):
+    auto_model_class = AutoModelForAudioClassification
+    export_feature = "audio-classification"
+
+    def forward(
+        self,
+        input_values: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        **kwargs,
+    ):
+        inputs = {
+            "input_values": input_values,
+        }
+
+        if "attention_mask" in self.input_names:
+            inputs["attention_mask"] = attention_mask
+
+        outputs = self.model(**inputs)
+        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
 
 
 class IPEXModelForQuestionAnswering(IPEXModel):
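
A hedged usage sketch for the new heads, assuming the `from_pretrained(..., export=True)` flow the existing IPEX classes expose and a top-level `optimum.intel` export (the checkpoint name is illustrative):

```python
from transformers import AutoFeatureExtractor, pipeline
from optimum.intel import IPEXModelForAudioClassification

model_id = "superb/hubert-base-superb-ks"  # illustrative checkpoint
# export=True traces the model to TorchScript on load, as with the other IPEX heads
model = IPEXModelForAudioClassification.from_pretrained(model_id, export=True)
extractor = AutoFeatureExtractor.from_pretrained(model_id)
clf = pipeline("audio-classification", model=model, feature_extractor=extractor)
print(clf("sample.wav"))
```
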
@@ -233,9 +290,6 @@ def __init__(
 
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", None)
-        self.input_names = {
-            inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
-        }
         self.use_cache = "past_key_values" in self.input_names
 
         if use_cache ^ self.use_cache:
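
The block removed here is the same `input_names` set now built once in the base-class `__init__` (second hunk above), so the cache check still works for this subclass. On the unchanged `use_cache ^ self.use_cache` context line: XOR is true exactly when the requested flag and the cache support detected from the traced graph disagree, as this one-liner shows.

```python
# a ^ b on booleans is True only when the flags differ
for a, b in [(False, False), (False, True), (True, False), (True, True)]:
    print(a, b, a ^ b)
```
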