Skip to content

Commit 151712d

Browse files
committed
rm autocast in model
1 parent e05557a commit 151712d

File tree

1 file changed

+5
-13
lines changed

1 file changed

+5
-13
lines changed

optimum/intel/ipex/modeling_base.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def forward(
193193
if "token_type_ids" in self.input_names:
194194
inputs["token_type_ids"] = token_type_ids
195195

196-
outputs = self._call_model(**inputs)
196+
outputs = self.model(**inputs)
197197
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
198198

199199
def eval(self):
@@ -216,14 +216,6 @@ def to(self, device: Union[torch.device, str]):
216216
def can_generate(self):
217217
return isinstance(self, GenerationMixin)
218218

219-
def _call_model(self, *args, **kwargs):
220-
try:
221-
with torch.autocast(self.device.type, self.dtype):
222-
out = self.model(*args, **kwargs)
223-
except RuntimeError:
224-
out = self.model(*args, **kwargs)
225-
return out
226-
227219
def _init_warmup(self):
228220
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
229221
# the results of the compute are unpredictable
@@ -261,7 +253,7 @@ def forward(
261253
"pixel_values": pixel_values,
262254
}
263255

264-
outputs = self._call_model(**inputs)
256+
outputs = self.model(**inputs)
265257
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
266258

267259

@@ -282,7 +274,7 @@ def forward(
282274
if "attention_mask" in self.input_names:
283275
inputs["attention_mask"] = attention_mask
284276

285-
outputs = self._call_model(**inputs)
277+
outputs = self.model(**inputs)
286278
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
287279

288280

@@ -305,7 +297,7 @@ def forward(
305297
if "token_type_ids" in self.input_names:
306298
inputs["token_type_ids"] = token_type_ids
307299

308-
outputs = self._call_model(**inputs)
300+
outputs = self.model(**inputs)
309301
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
310302
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
311303
return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -451,7 +443,7 @@ def forward(
451443
inputs["past_key_values"] = past_key_values
452444

453445
# 2. Model forward
454-
outputs = self._call_model(**inputs)
446+
outputs = self.model(**inputs)
455447

456448
# 3. Process model outputs
457449
if isinstance(outputs, (list, tuple)):

0 commit comments

Comments
 (0)