Skip to content

Commit c356aa3

Browse files
authored
Change model output parameter to last_hidden_states for IPEXModel (#589)
* change model output parameter to last_hidden_states * update ipex model testiong * update testing * add output name to ipex model
1 parent 8c95cae commit c356aa3

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

optimum/intel/ipex/modeling_base.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class IPEXModel(OptimizedModel):
5858
export_feature = "feature-extraction"
5959
base_model_prefix = "ipex_model"
6060
main_input_name = "input_ids"
61+
output_name = "last_hidden_state"
6162

6263
def __init__(
6364
self,
@@ -193,7 +194,12 @@ def forward(
193194
inputs["token_type_ids"] = token_type_ids
194195

195196
outputs = self._call_model(**inputs)
196-
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
197+
if isinstance(outputs, dict):
198+
model_output = ModelOutput(**outputs)
199+
else:
200+
model_output = ModelOutput()
201+
model_output[self.output_name] = outputs[0]
202+
return model_output
197203

198204
def eval(self):
199205
self.model.eval()
@@ -235,16 +241,19 @@ def _init_warmup(self):
235241
class IPEXModelForSequenceClassification(IPEXModel):
236242
auto_model_class = AutoModelForSequenceClassification
237243
export_feature = "text-classification"
244+
output_name = "logits"
238245

239246

240247
class IPEXModelForTokenClassification(IPEXModel):
241248
auto_model_class = AutoModelForTokenClassification
242249
export_feature = "token-classification"
250+
output_name = "logits"
243251

244252

245253
class IPEXModelForMaskedLM(IPEXModel):
246254
auto_model_class = AutoModelForMaskedLM
247255
export_feature = "fill-mask"
256+
output_name = "logits"
248257

249258

250259
class IPEXModelForImageClassification(IPEXModel):

0 commit comments

Comments
 (0)