@@ -58,6 +58,7 @@ class IPEXModel(OptimizedModel):
58
58
export_feature = "feature-extraction"
59
59
base_model_prefix = "ipex_model"
60
60
main_input_name = "input_ids"
61
+ output_name = "last_hidden_state"
61
62
62
63
def __init__ (
63
64
self ,
@@ -193,7 +194,12 @@ def forward(
193
194
inputs ["token_type_ids" ] = token_type_ids
194
195
195
196
outputs = self ._call_model (** inputs )
196
- return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (last_hidden_state = outputs [0 ])
197
+ if isinstance (outputs , dict ):
198
+ model_output = ModelOutput (** outputs )
199
+ else :
200
+ model_output = ModelOutput ()
201
+ model_output [self .output_name ] = outputs [0 ]
202
+ return model_output
197
203
198
204
def eval (self ):
199
205
self .model .eval ()
@@ -235,16 +241,19 @@ def _init_warmup(self):
235
241
class IPEXModelForSequenceClassification (IPEXModel ):
236
242
auto_model_class = AutoModelForSequenceClassification
237
243
export_feature = "text-classification"
244
+ output_name = "logits"
238
245
239
246
240
247
class IPEXModelForTokenClassification (IPEXModel ):
241
248
auto_model_class = AutoModelForTokenClassification
242
249
export_feature = "token-classification"
250
+ output_name = "logits"
243
251
244
252
245
253
class IPEXModelForMaskedLM (IPEXModel ):
246
254
auto_model_class = AutoModelForMaskedLM
247
255
export_feature = "fill-mask"
256
+ output_name = "logits"
248
257
249
258
250
259
class IPEXModelForImageClassification (IPEXModel ):
0 commit comments