@@ -193,7 +193,7 @@ def forward(
193
193
if "token_type_ids" in self .input_names :
194
194
inputs ["token_type_ids" ] = token_type_ids
195
195
196
- outputs = self ._call_model (** inputs )
196
+ outputs = self .model (** inputs )
197
197
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
198
198
199
199
def eval (self ):
@@ -216,14 +216,6 @@ def to(self, device: Union[torch.device, str]):
216
216
def can_generate (self ):
217
217
return isinstance (self , GenerationMixin )
218
218
219
- def _call_model (self , * args , ** kwargs ):
220
- try :
221
- with torch .autocast (self .device .type , self .dtype ):
222
- out = self .model (* args , ** kwargs )
223
- except RuntimeError :
224
- out = self .model (* args , ** kwargs )
225
- return out
226
-
227
219
def _init_warmup (self ):
228
220
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
229
221
# the results of the compute are unpredictable
@@ -261,7 +253,7 @@ def forward(
261
253
"pixel_values" : pixel_values ,
262
254
}
263
255
264
- outputs = self ._call_model (** inputs )
256
+ outputs = self .model (** inputs )
265
257
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
266
258
267
259
@@ -282,7 +274,7 @@ def forward(
282
274
if "attention_mask" in self .input_names :
283
275
inputs ["attention_mask" ] = attention_mask
284
276
285
- outputs = self ._call_model (** inputs )
277
+ outputs = self .model (** inputs )
286
278
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
287
279
288
280
@@ -305,7 +297,7 @@ def forward(
305
297
if "token_type_ids" in self .input_names :
306
298
inputs ["token_type_ids" ] = token_type_ids
307
299
308
- outputs = self ._call_model (** inputs )
300
+ outputs = self .model (** inputs )
309
301
start_logits = outputs ["start_logits" ] if isinstance (outputs , dict ) else outputs [0 ]
310
302
end_logits = outputs ["end_logits" ] if isinstance (outputs , dict ) else outputs [1 ]
311
303
return ModelOutput (start_logits = start_logits , end_logits = end_logits )
@@ -451,7 +443,7 @@ def forward(
451
443
inputs ["past_key_values" ] = past_key_values
452
444
453
445
# 2. Model forward
454
- outputs = self ._call_model (** inputs )
446
+ outputs = self .model (** inputs )
455
447
456
448
# 3. Process model outputs
457
449
if isinstance (outputs , (list , tuple )):
0 commit comments