Skip to content

Commit abb7b00

Browse files
committed
Fix output handling in IPEX question answering
1 parent d797cc9 commit abb7b00

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

optimum/intel/ipex/modeling_base.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel):
290290
auto_model_class = AutoModelForQuestionAnswering
291291
export_feature = "question-answering"
292292

293-
@wraps(IPEXModel.forward)
294-
def forward(self, *args, **kwargs):
295-
outputs = super().forward(*args, **kwargs)
293+
def forward(self,
294+
input_ids: torch.Tensor,
295+
attention_mask: torch.Tensor,
296+
token_type_ids: torch.Tensor = None,
297+
**kwargs,
298+
):
299+
inputs = {
300+
"input_ids": input_ids,
301+
"attention_mask": attention_mask,
302+
}
303+
304+
if "token_type_ids" in self.input_names:
305+
inputs["token_type_ids"] = token_type_ids
306+
307+
outputs = self._call_model(**inputs)
296308
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
297309
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
298310
return ModelOutput(start_logits=start_logits, end_logits=end_logits)

0 commit comments

Comments
 (0)