@@ -69,6 +69,7 @@ def __init__(
69
69
OptimizedModel .__init__ (self , model = model , config = config )
70
70
# To do: add XPU support
71
71
self ._device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
72
+ self ._dtype = self .config .torch_dtype if self .config .torch_dtype is not None else torch .float32
72
73
self .model .to (self ._device )
73
74
self .model_save_dir = model_save_dir
74
75
@@ -188,7 +189,7 @@ def forward(
188
189
if "token_type_ids" in self .input_names :
189
190
inputs ["token_type_ids" ] = token_type_ids
190
191
191
- outputs = self .model (** inputs )
192
+ outputs = self ._call_model (** inputs )
192
193
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
193
194
194
195
def eval (self ):
@@ -199,6 +200,10 @@ def eval(self):
199
200
def device (self ) -> torch .device :
200
201
return self ._device
201
202
203
@property
def dtype(self) -> torch.dtype:
    """Torch dtype the wrapped model is executed with (used for autocast).

    Falls back to ``torch.float32`` at construction time when the config
    does not specify ``torch_dtype`` — see ``__init__``.
    """
    return self._dtype
202
207
def to (self , device : Union [torch .device , str ]):
203
208
self ._device = device if isinstance (device , torch .device ) else torch .device (device )
204
209
self .model .to (self ._device )
@@ -207,6 +212,14 @@ def to(self, device: Union[torch.device, str]):
207
212
def can_generate (self ):
208
213
return isinstance (self , GenerationMixin )
209
214
215
+ def _call_model (self , * args , ** kwargs ):
216
+ try :
217
+ with torch .autocast (self .device .type , self .dtype ):
218
+ out = self .model (* args , ** kwargs )
219
+ except RuntimeError :
220
+ out = self .model (* args , ** kwargs )
221
+ return out
222
+
210
223
211
224
class IPEXModelForSequenceClassification (IPEXModel ):
212
225
auto_model_class = AutoModelForSequenceClassification
@@ -236,7 +249,7 @@ def forward(
236
249
"pixel_values" : pixel_values ,
237
250
}
238
251
239
- outputs = self .model (** inputs )
252
+ outputs = self ._call_model (** inputs )
240
253
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
241
254
242
255
@@ -257,7 +270,7 @@ def forward(
257
270
if "attention_mask" in self .input_names :
258
271
inputs ["attention_mask" ] = attention_mask
259
272
260
- outputs = self .model (** inputs )
273
+ outputs = self ._call_model (** inputs )
261
274
return ModelOutput (** outputs ) if isinstance (outputs , dict ) else ModelOutput (logits = outputs [0 ])
262
275
263
276
@@ -266,7 +279,7 @@ class IPEXModelForQuestionAnswering(IPEXModel):
266
279
export_feature = "question-answering"
267
280
268
281
def forward(self, *args, **kwargs):
    """Run question-answering inference and normalize the result.

    Delegates to ``_call_model`` and extracts the start/end logits from
    either a dict-like or a tuple/list-like model output, returning them
    as a ``ModelOutput``.
    """
    raw = self._call_model(*args, **kwargs)
    if isinstance(raw, dict):
        start, end = raw["start_logits"], raw["end_logits"]
    else:
        start, end = raw[0], raw[1]
    return ModelOutput(start_logits=start, end_logits=end)
@@ -287,7 +300,7 @@ def __init__(
287
300
super ().__init__ (model , config , model_save_dir = model_save_dir )
288
301
289
302
self .normalized_config = NormalizedConfigManager .get_normalized_config_class (config .model_type )(config )
290
- self .model_dtype = kwargs .get ("model_dtype" , None )
303
+ self .model_dtype = kwargs .get ("model_dtype" , self . dtype )
291
304
self .use_cache = "past_key_values" in self .input_names
292
305
293
306
if use_cache ^ self .use_cache :
@@ -377,7 +390,7 @@ def forward(
377
390
inputs ["past_key_values" ] = past_key_values
378
391
379
392
# 2. Model forward
380
- outputs = self .model (** inputs )
393
+ outputs = self ._call_model (** inputs )
381
394
382
395
# 3. Process model outputs
383
396
if isinstance (outputs , (list , tuple )):
0 commit comments