File tree 1 file changed +15
-3
lines changed
1 file changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -290,9 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel):
290
290
auto_model_class = AutoModelForQuestionAnswering
291
291
export_feature = "question-answering"
292
292
293
- @wraps (IPEXModel .forward )
294
- def forward (self , * args , ** kwargs ):
295
- outputs = super ().forward (* args , ** kwargs )
293
+ def forward (self ,
294
+ input_ids : torch .Tensor ,
295
+ attention_mask : torch .Tensor ,
296
+ token_type_ids : torch .Tensor = None ,
297
+ ** kwargs ,
298
+ ):
299
+ inputs = {
300
+ "input_ids" : input_ids ,
301
+ "attention_mask" : attention_mask ,
302
+ }
303
+
304
+ if "token_type_ids" in self .input_names :
305
+ inputs ["token_type_ids" ] = token_type_ids
306
+
307
+ outputs = self ._call_model (** inputs )
296
308
start_logits = outputs ["start_logits" ] if isinstance (outputs , dict ) else outputs [0 ]
297
309
end_logits = outputs ["end_logits" ] if isinstance (outputs , dict ) else outputs [1 ]
298
310
return ModelOutput (start_logits = start_logits , end_logits = end_logits )
You can’t perform that action at this time.
0 commit comments