Skip to content

Commit 8ee487d

Browse files
authored
Automatic torch.autocast for IPEXModel (#542)
* Handle autocast in IPEXModel.forward * Handle missing torch_dtype in config
1 parent 398450d commit 8ee487d

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

optimum/intel/ipex/modeling_base.py

+19-6
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@ def __init__(
6969
OptimizedModel.__init__(self, model=model, config=config)
7070
# To do: add XPU support
7171
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
72+
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
7273
self.model.to(self._device)
7374
self.model_save_dir = model_save_dir
7475

@@ -188,7 +189,7 @@ def forward(
188189
if "token_type_ids" in self.input_names:
189190
inputs["token_type_ids"] = token_type_ids
190191

191-
outputs = self.model(**inputs)
192+
outputs = self._call_model(**inputs)
192193
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
193194

194195
def eval(self):
@@ -199,6 +200,10 @@ def eval(self):
199200
def device(self) -> torch.device:
200201
return self._device
201202

203+
@property
204+
def dtype(self) -> torch.dtype:
205+
return self._dtype
206+
202207
def to(self, device: Union[torch.device, str]):
203208
self._device = device if isinstance(device, torch.device) else torch.device(device)
204209
self.model.to(self._device)
@@ -207,6 +212,14 @@ def to(self, device: Union[torch.device, str]):
207212
def can_generate(self):
208213
return isinstance(self, GenerationMixin)
209214

215+
def _call_model(self, *args, **kwargs):
216+
try:
217+
with torch.autocast(self.device.type, self.dtype):
218+
out = self.model(*args, **kwargs)
219+
except RuntimeError:
220+
out = self.model(*args, **kwargs)
221+
return out
222+
210223

211224
class IPEXModelForSequenceClassification(IPEXModel):
212225
auto_model_class = AutoModelForSequenceClassification
@@ -236,7 +249,7 @@ def forward(
236249
"pixel_values": pixel_values,
237250
}
238251

239-
outputs = self.model(**inputs)
252+
outputs = self._call_model(**inputs)
240253
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
241254

242255

@@ -257,7 +270,7 @@ def forward(
257270
if "attention_mask" in self.input_names:
258271
inputs["attention_mask"] = attention_mask
259272

260-
outputs = self.model(**inputs)
273+
outputs = self._call_model(**inputs)
261274
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
262275

263276

@@ -266,7 +279,7 @@ class IPEXModelForQuestionAnswering(IPEXModel):
266279
export_feature = "question-answering"
267280

268281
def forward(self, *args, **kwargs):
269-
outputs = self.model(*args, **kwargs)
282+
outputs = self._call_model(*args, **kwargs)
270283
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
271284
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
272285
return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -287,7 +300,7 @@ def __init__(
287300
super().__init__(model, config, model_save_dir=model_save_dir)
288301

289302
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
290-
self.model_dtype = kwargs.get("model_dtype", None)
303+
self.model_dtype = kwargs.get("model_dtype", self.dtype)
291304
self.use_cache = "past_key_values" in self.input_names
292305

293306
if use_cache ^ self.use_cache:
@@ -377,7 +390,7 @@ def forward(
377390
inputs["past_key_values"] = past_key_values
378391

379392
# 2. Model forward
380-
outputs = self.model(**inputs)
393+
outputs = self._call_model(**inputs)
381394

382395
# 3. Process model outputs
383396
if isinstance(outputs, (list, tuple)):

0 commit comments

Comments (0)