Skip to content

Commit 0667c3e

Browse files
committed
Handle autocast in IPEXModel.forward
1 parent 3b627f4 commit 0667c3e

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

optimum/intel/ipex/modeling_base.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
OptimizedModel.__init__(self, model=model, config=config)
6868
# To do: add XPU support
6969
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
70+
self._dtype = self.config.torch_dtype
7071
self.model.to(self._device)
7172
self.model_save_dir = model_save_dir
7273

@@ -190,7 +191,7 @@ def forward(
190191
if "token_type_ids" in self.input_names:
191192
inputs["token_type_ids"] = token_type_ids
192193

193-
outputs = self.model(**inputs)
194+
outputs = self._call_model(**inputs)
194195
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
195196

196197
def eval(self):
@@ -201,6 +202,10 @@ def eval(self):
201202
def device(self) -> torch.device:
202203
return self._device
203204

205+
# Read-only accessor: returns the value stored in `self._dtype`.
# In the `__init__` hunk of this diff, `_dtype` is assigned from
# `self.config.torch_dtype`.
# NOTE(review): the `-> torch.dtype` annotation presumes the config value
# is always a torch.dtype; whether it can be None is not visible here — confirm.
@property
206+
def dtype(self) -> torch.dtype:
207+
return self._dtype
208+
204209
def to(self, device: Union[torch.device, str]):
205210
self._device = device if isinstance(device, torch.device) else torch.device(device)
206211
self.model.to(self._device)
@@ -209,6 +214,14 @@ def to(self, device: Union[torch.device, str]):
209214
# Reports whether this instance supports generation: True exactly when the
# object is an instance of `GenerationMixin` (plain isinstance check).
def can_generate(self):
210215
return isinstance(self, GenerationMixin)
211216

217+
def _call_model(self, *args, **kwargs):
    """Run the wrapped model, under autocast when the model dtype calls for it.

    Autocast is only entered when ``self.dtype`` is a reduced-precision
    floating dtype (``torch.float16`` or ``torch.bfloat16``). For any other
    value (``None``, ``torch.float32``, ...) the model is called directly.

    Args:
        *args: Positional arguments forwarded unchanged to ``self.model``.
        **kwargs: Keyword arguments forwarded unchanged to ``self.model``.

    Returns:
        Whatever ``self.model`` returns.

    Raises:
        RuntimeError: If the model fails both under autocast and on the
            plain retry (the retry's error propagates, chained to the first).
    """
    # Passing dtype=None makes torch.autocast pick the device's default
    # reduced-precision dtype, and passing float32 makes it warn and disable
    # itself on every call. Neither is wanted for a full-precision model,
    # so skip the context manager entirely in those cases.
    if self.dtype not in (torch.float16, torch.bfloat16):
        return self.model(*args, **kwargs)

    try:
        with torch.autocast(self.device.type, self.dtype):
            return self.model(*args, **kwargs)
    except RuntimeError:
        # Best-effort fallback kept from the original implementation: if the
        # autocast run fails, retry once without autocast.
        # NOTE(review): this also retries on RuntimeErrors unrelated to
        # autocast, running the model twice before the error surfaces.
        return self.model(*args, **kwargs)
224+
212225

213226
class IPEXModelForSequenceClassification(IPEXModel):
214227
auto_model_class = AutoModelForSequenceClassification
@@ -238,7 +251,7 @@ def forward(
238251
"pixel_values": pixel_values,
239252
}
240253

241-
outputs = self.model(**inputs)
254+
outputs = self._call_model(**inputs)
242255
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
243256

244257

@@ -259,7 +272,7 @@ def forward(
259272
if "attention_mask" in self.input_names:
260273
inputs["attention_mask"] = attention_mask
261274

262-
outputs = self.model(**inputs)
275+
outputs = self._call_model(**inputs)
263276
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
264277

265278

@@ -268,7 +281,7 @@ class IPEXModelForQuestionAnswering(IPEXModel):
268281
export_feature = "question-answering"
269282

270283
# Question-answering forward pass: forwards all positional and keyword
# arguments to the wrapped model and repackages the result as a ModelOutput
# with `start_logits` and `end_logits`.
# This diff replaces the direct `self.model(...)` call (the `-` line) with
# `self._call_model(...)` (the `+` line).
# Dict outputs are read by the keys "start_logits" / "end_logits"; any other
# output type is indexed positionally (index 0 = start, index 1 = end).
def forward(self, *args, **kwargs):
271-
outputs = self.model(*args, **kwargs)
284+
outputs = self._call_model(*args, **kwargs)
272285
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
273286
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
274287
return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -289,7 +302,7 @@ def __init__(
289302
super().__init__(model, config, model_save_dir=model_save_dir)
290303

291304
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
292-
self.model_dtype = kwargs.get("model_dtype", None)
305+
self.model_dtype = kwargs.get("model_dtype", self.dtype)
293306
self.use_cache = "past_key_values" in self.input_names
294307

295308
if use_cache ^ self.use_cache:
@@ -367,7 +380,7 @@ def forward(
367380
inputs["past_key_values"] = past_key_values
368381

369382
# 2. Model forward
370-
outputs = self.model(**inputs)
383+
outputs = self._call_model(**inputs)
371384

372385
# 3. Process model outputs
373386
if isinstance(outputs, (list, tuple)):

0 commit comments

Comments
 (0)