
Commit 4c73016

yanboliang authored and pytorchmergebot committed
[Dynamo] Enable torch._dynamo.config.suppress_errors by default (pytorch#105307)
Summary: We are working toward full model compilation, where if a compilation error happens we fall back to eager mode rather than erroring out. At the same time, we should still fix these issues when they are bugs. To keep them from being ignored, we will:

1. log warnings in OSS;
2. log warnings and write them into Scuba in fbcode.

Test Plan: Manual test

Differential Revision: D47506314

Pull Request resolved: pytorch#105307
Approved by: https://github.com/jansel
1 parent de8bd10 · commit 4c73016

8 files changed (+64, -49 lines)

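With this change, error suppression is on by default in OSS builds. A minimal sketch of what that means for a user, using only the config surface that appears in the diffs below:

import torch
import torch._dynamo

# New default: a compilation error is logged as a warning and the frame
# falls back to eager execution instead of raising.
print(torch._dynamo.config.suppress_errors)  # True

# Opting back into hard failures, in Python...
torch._dynamo.config.suppress_errors = False
# ...or from the shell: TORCHDYNAMO_SUPPRESS_ERRORS=0 python your_script.py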

test/dynamo/test_logging.py (+4, -4)
@@ -112,20 +112,20 @@ def test_dynamo_debug_default_off_artifacts(self, records):
         self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
         self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)
 
-    @make_logging_test(dynamo=logging.ERROR)
+    @make_logging_test()
     def test_dynamo_error(self, records):
         try:
             fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
             fn_opt(*ARGS)
         except Exception:
             pass
-        self.assertEqual(len(records), 1)
+        self.assertEqual(len(records), 2)
 
     test_aot = within_range_record_test(2, 6, aot=logging.INFO)
     test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
     test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)
 
-    @make_logging_test(dynamo=logging.ERROR)
+    @make_logging_test()
     def test_inductor_error(self, records):
         exitstack = contextlib.ExitStack()
         import torch._inductor.lowering
@@ -148,7 +148,7 @@ def throw(x):
             fn_opt(*ARGS)
         except Exception:
             pass
-        self.assertEqual(len(records), 1)
+        self.assertEqual(len(records), 2)
         self.assertIsInstance(records[0].msg, str)
 
         exitstack.close()
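The expected record count rises from 1 to 2, consistent with the convert_frame.py change below: a failed compile now logs both the formatted "WON'T CONVERT" message from exception_handler() and the "converting frame raised error, suppressing error" warning from _convert_frame(). The decorators also drop dynamo=logging.ERROR, since with suppression on by default these records are emitted at WARNING level under default log settings.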

test/functorch/test_eager_transforms.py (+2)
@@ -4738,6 +4738,8 @@ class TestCompileTransforms(TestCase):
     # torch.compile is not supported on Windows
     # Triton only supports GPU with SM70 or later.
     @expectedFailureIf(IS_WINDOWS or (TEST_CUDA and not SM70OrLater))
+    @torch._dynamo.config.patch(suppress_errors=False)
+    @skipIfTorchDynamo("Do not test torch.compile on top of torch.compile")
     def test_compile_vmap_hessian(self, device):
         # The model and inputs are a smaller version
         # of code at benchmark repo:
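torch._dynamo.config.patch works both as a decorator (as above) and as a context manager, so tests that must see real compiler failures can opt out of the new default locally. A small sketch, assuming only the public config API (the lambda and tensor are illustrative):

import torch
import torch._dynamo

# Inside this block a compile error raises instead of silently falling
# back to eager, so a real compiler bug would not be swallowed.
with torch._dynamo.config.patch(suppress_errors=False):
    fn = torch.compile(lambda x: x * 2)
    fn(torch.ones(4))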

test/inductor/test_cpu_repro.py (-9)
@@ -1891,15 +1891,6 @@ def fn(x):
         self.assertTrue(same(fn(x), opt_fn(x)))
         assert metrics.generated_cpp_vec_kernel_count == 2
 
-    def test_invalid_index_of_empty_tensor(self):
-        def fn(a):
-            b = a[[0]]
-            return b
-
-        a = torch.tensor([])
-        with self.assertRaises(RuntimeError):
-            torch.compile(fn)(a)
-
     def test_ir_node_str(self):
         @torch.compile
         def fn(x: torch.Tensor) -> torch.Tensor:
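This test asserted that torch.compile raises a RuntimeError for an out-of-range index into an empty tensor. With suppress_errors on by default, the compiler-side error is logged and swallowed and the frame falls back to eager, so the assertRaises expectation no longer holds; that is the apparent rationale for removing it (the commit message does not call it out explicitly).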

torch/_dynamo/config.py (+1, -1)
@@ -98,7 +98,7 @@
 # This is a good way to get your model to work one way or another, but you may
 # lose optimization opportunities this way. Devs, if your benchmark model is failing
 # this way, you should figure out why instead of suppressing it.
-suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
+suppress_errors = os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", "1") == "1"
 
 # Record and write an execution record of the current frame to a file
 # if an exception is encountered
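Note the parsing change as well as the flipped default: with the old bool(os.environ.get(...)) form, any non-empty value, including "0", enabled suppression. A quick sketch of the difference:

import os

os.environ["TORCHDYNAMO_SUPPRESS_ERRORS"] = "0"

# Old parsing: bool("0") is True, so even "0" enabled suppression.
old = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))  # True

# New parsing: on by default, and only the exact string "1" keeps it on,
# so TORCHDYNAMO_SUPPRESS_ERRORS=0 now genuinely opts out.
new = os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", "1") == "1"  # False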

torch/_dynamo/convert_frame.py (+15, -2)
@@ -33,6 +33,7 @@
     augment_exc_message,
     BackendCompilerFailed,
     format_error_msg,
+    format_error_msg_verbose,
     InternalTorchDynamoError,
     TorchRuntimeError,
     unimplemented,
@@ -215,7 +216,19 @@ def exception_handler(e, code, frame=None):
     # Only log the exception if we are going to suppress it
     # if aren't suppressing it, a higher level except block will handle it
     if config.suppress_errors:
-        log.error(format_error_msg(e, code, record_filename, frame))
+        if config.is_fbcode():
+            from torch._dynamo.fb.logging import (  # type: ignore[import]
+                log_dynamo_suppress_errors,
+            )
+
+            error_msg = format_error_msg_verbose(e, code, record_filename, frame)
+            log_dynamo_suppress_errors(
+                code.co_name, code.co_filename, code.co_firstlineno, error_msg
+            )
+        else:
+            error_msg = format_error_msg(e, code, record_filename, frame)
+
+        log.warning(error_msg)
 
 
 FRAME_COUNTER = 0
@@ -551,7 +564,7 @@ def _convert_frame(
         except Exception:
             if not config.suppress_errors:
                 raise
-            log.info("converting frame raised error, suppressing error")
+            log.warning("converting frame raised error, suppressing error")
             return None
 
     _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
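Taken together, the suppression path now behaves roughly like this condensed sketch (not the verbatim source; log and config here stand in for the module's own objects):

import logging

log = logging.getLogger("torch._dynamo")

class config:  # stand-in for torch._dynamo.config
    suppress_errors = True

def convert_frame_sketch(compile_frame, frame):
    # On failure with suppression enabled: warn and return None, which
    # leaves the frame to run in ordinary eager mode.
    try:
        return compile_frame(frame)
    except Exception:
        if not config.suppress_errors:
            raise
        log.warning("converting frame raised error, suppressing error")
        return None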

torch/_dynamo/eval_frame.py (+2, -2)
@@ -160,7 +160,7 @@ def remove_from_cache(f):
     elif hasattr(getattr(f, "forward", None), "__code__"):
         reset_code(f.forward.__code__)
     else:
-        from . import reset
+        from . import reset  # type: ignore[attr-defined]
 
         reset()
         log.warning("could not determine __code__ for %s", f)
@@ -591,7 +591,7 @@ def toy_example(a, b):
 @patch("torch._dynamo.symbolic_convert.explain", True)
 def explain(f, *args, **kwargs):
     # TODO(voz): Do we want a decorator for this?
-    from . import reset
+    from . import reset  # type: ignore[attr-defined]
 
     reset()

torch/_dynamo/exc.py (+34, -29)
@@ -226,39 +226,44 @@ def filter_stack(stack):
     return user_stack
 
 
-def format_error_msg(exc, code, record_filename=None, frame=None):
-    msg = os.linesep * 2
+def format_error_msg_verbose(exc, code, record_filename=None, frame=None):
+    msg = str(
+        format_bytecode(
+            "WON'T CONVERT",
+            code.co_name,
+            code.co_filename,
+            code.co_firstlineno,
+            code,
+        )
+    )
+    msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
+    msg += format_exc()
+    if hasattr(exc, "real_stack"):
+        msg += (
+            "\n"
+            + "=" * 10
+            + " The above exception occurred while processing the following code "
+            + "=" * 10
+            + "\n\n"
+        )
+        stack_above_dynamo = []
+        if frame is not None:
+            stack_above_dynamo = filter_stack(extract_stack(frame))
 
-    if config.verbose:
-        msg = str(
-            format_bytecode(
-                "WON'T CONVERT",
-                code.co_name,
-                code.co_filename,
-                code.co_firstlineno,
-                code,
-            )
+        msg += "".join(
+            format_list(stack_above_dynamo + list(reversed(get_real_stack(exc))))
         )
-        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
-        msg += format_exc()
-        if hasattr(exc, "real_stack"):
-            msg += (
-                "\n"
-                + "=" * 10
-                + " The above exception occurred while processing the following code "
-                + "=" * 10
-                + "\n\n"
-            )
-            stack_above_dynamo = []
-            if frame is not None:
-                stack_above_dynamo = filter_stack(extract_stack(frame))
+        msg += "\n"
+        msg += "=" * 10
+
+    return msg
 
-            msg += "".join(
-                format_list(stack_above_dynamo + list(reversed(get_real_stack(exc))))
-            )
-            msg += "\n"
-            msg += "=" * 10
 
+def format_error_msg(exc, code, record_filename=None, frame=None):
+    msg = os.linesep * 2
+
+    if config.verbose:
+        msg = format_error_msg_verbose(exc, code, record_filename, frame)
     else:
         msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
 line {code.co_firstlineno} \ndue to: \n{format_exc(limit=-1)}"
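In OSS, format_error_msg only emits the verbose report when config.verbose is set, while the fbcode path in convert_frame.py always builds the verbose form for the Scuba record. A one-liner sketch for OSS debugging (an assumption that config.verbose is the intended switch here; TORCHDYNAMO_VERBOSE is its usual env-var counterpart):

import torch._dynamo

torch._dynamo.config.verbose = True  # opt into the full "WON'T CONVERT" report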

torch/testing/_internal/logging_utils.py (+6, -2)
@@ -75,8 +75,12 @@ def test_fn(self):
             torch._dynamo.reset()
             records = []
             # run with env var
-            with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
-                fn(self, records)
+            if len(kwargs) == 0:
+                with self._handler_watcher(records):
+                    fn(self, records)
+            else:
+                with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
+                    fn(self, records)
 
             # run with API
             torch._dynamo.reset()
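The new branch simply skips log_settings when a test passes no kwargs, so @make_logging_test() now exercises the default logging configuration. An equivalent shape using contextlib.ExitStack (log_settings, kwargs_to_settings, and _handler_watcher are this file's own helpers; the function name is illustrative):

import contextlib

def run_captured(self, fn, records, **kwargs):
    # Enter log_settings only when specific settings were requested.
    with contextlib.ExitStack() as stack:
        if kwargs:
            stack.enter_context(log_settings(kwargs_to_settings(**kwargs)))
        stack.enter_context(self._handler_watcher(records))
        fn(self, records)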
