diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 2b6b569343..9928977ead 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -58,6 +58,7 @@ class IPEXModel(OptimizedModel): export_feature = "feature-extraction" base_model_prefix = "ipex_model" main_input_name = "input_ids" + output_name = "last_hidden_state" def __init__( self, @@ -193,7 +194,12 @@ def forward( inputs["token_type_ids"] = token_type_ids outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) + if isinstance(outputs, dict): + model_output = ModelOutput(**outputs) + else: + model_output = ModelOutput() + model_output[self.output_name] = outputs[0] + return model_output def eval(self): self.model.eval() @@ -235,16 +241,19 @@ def _init_warmup(self): class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification export_feature = "text-classification" + output_name = "logits" class IPEXModelForTokenClassification(IPEXModel): auto_model_class = AutoModelForTokenClassification export_feature = "token-classification" + output_name = "logits" class IPEXModelForMaskedLM(IPEXModel): auto_model_class = AutoModelForMaskedLM export_feature = "fill-mask" + output_name = "logits" class IPEXModelForImageClassification(IPEXModel):