@@ -67,6 +67,7 @@ def __init__(
67
67
OptimizedModel .__init__ (self , model = model , config = config )
68
68
# To do: add XPU support
69
69
self ._device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
70
+ self ._dtype = self .config .torch_dtype
70
71
self .model .to (self ._device )
71
72
self .model_save_dir = model_save_dir
72
73
@@ -190,7 +191,7 @@ def forward(
190
191
if "token_type_ids" in self .input_names :
191
192
inputs ["token_type_ids" ] = token_type_ids
192
193
193
- outputs = self .model (** inputs )
194
+ outputs = self ._call_model (** inputs )
194
195
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
195
196
196
197
def eval (self ):
@@ -201,6 +202,10 @@ def eval(self):
201
202
def device (self ) -> torch .device :
202
203
return self ._device
203
204
205
@property
def dtype(self) -> torch.dtype:
    """torch.dtype configured for this model (read from ``config.torch_dtype`` at init)."""
    return self._dtype
204
209
def to (self , device : Union [torch .device , str ]):
205
210
self ._device = device if isinstance (device , torch .device ) else torch .device (device )
206
211
self .model .to (self ._device )
@@ -209,6 +214,14 @@ def to(self, device: Union[torch.device, str]):
209
214
def can_generate(self):
    """Return True when this instance mixes in ``GenerationMixin`` (i.e. exposes ``generate()``)."""
    return isinstance(self, GenerationMixin)
217
def _call_model(self, *args, **kwargs):
    """Run the wrapped model, using mixed-precision autocast when a dtype is configured.

    All positional and keyword arguments are forwarded unchanged to
    ``self.model``. If autocast is unsupported for this device/dtype
    combination (``torch.autocast`` raises ``RuntimeError``), fall back to a
    plain call so the forward still succeeds.
    """
    if self._dtype is None:
        # No explicit dtype configured (config.torch_dtype is None):
        # do not autocast — torch.autocast(device_type, None) would pick a
        # device-default reduced-precision dtype and silently change numerics.
        return self.model(*args, **kwargs)
    try:
        with torch.autocast(self.device.type, self.dtype):
            return self.model(*args, **kwargs)
    except RuntimeError:
        # NOTE(review): this broad catch also retries when the model itself
        # raises RuntimeError inside autocast, re-running any side effects
        # once — confirm this best-effort fallback is intended.
        return self.model(*args, **kwargs)
212
225
213
226
class IPEXModelForSequenceClassification (IPEXModel ):
214
227
auto_model_class = AutoModelForSequenceClassification
@@ -238,7 +251,7 @@ def forward(
238
251
"pixel_values" : pixel_values ,
239
252
}
240
253
241
- outputs = self .model (** inputs )
254
+ outputs = self ._call_model (** inputs )
242
255
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
243
256
244
257
@@ -259,7 +272,7 @@ def forward(
259
272
if "attention_mask" in self .input_names :
260
273
inputs ["attention_mask" ] = attention_mask
261
274
262
- outputs = self .model (** inputs )
275
+ outputs = self ._call_model (** inputs )
263
276
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
264
277
265
278
@@ -268,7 +281,7 @@ class IPEXModelForQuestionAnswering(IPEXModel):
268
281
export_feature = "question-answering"
269
282
270
283
def forward(self, *args, **kwargs):
    """Run the QA model and wrap its start/end logits in a ``ModelOutput``.

    Accepts either a dict-like model output (keyed by ``start_logits`` /
    ``end_logits``) or a positional tuple/list where index 0 is the start
    logits and index 1 the end logits.
    """
    raw = self._call_model(*args, **kwargs)
    if isinstance(raw, dict):
        start, end = raw["start_logits"], raw["end_logits"]
    else:
        start, end = raw[0], raw[1]
    return ModelOutput(start_logits=start, end_logits=end)
@@ -289,7 +302,7 @@ def __init__(
289
302
super ().__init__ (model , config , model_save_dir = model_save_dir )
290
303
291
304
self .normalized_config = NormalizedConfigManager .get_normalized_config_class (config .model_type )(config )
292
- self .model_dtype = kwargs .get ("model_dtype" , None )
305
+ self .model_dtype = kwargs .get ("model_dtype" , self . dtype )
293
306
self .use_cache = "past_key_values" in self .input_names
294
307
295
308
if use_cache ^ self .use_cache :
@@ -367,7 +380,7 @@ def forward(
367
380
inputs ["past_key_values" ] = past_key_values
368
381
369
382
# 2. Model forward
370
- outputs = self .model (** inputs )
383
+ outputs = self ._call_model (** inputs )
371
384
372
385
# 3. Process model outputs
373
386
if isinstance (outputs , (list , tuple )):
0 commit comments