
Commit 5bbec68

Chillee authored and pytorchmergebot committed
Fix usages of contextmanager without finally (pytorch#96170)
Pull Request resolved: pytorch#96170
Approved by: https://github.com/ngimel, https://github.com/malfet
1 parent 34d18c8 commit 5bbec68
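For background: contextlib.contextmanager turns a generator into a context manager, and an exception raised in the with body is re-raised at the generator's yield point. Without a try/finally, everything after the yield is skipped on the error path, so settings are never restored and resources leak. A minimal, self-contained sketch of the bug class and of the fix this commit applies throughout (names here are illustrative, not from the commit):

import contextlib

depth = 0  # stands in for whatever global state a context manager adjusts

@contextlib.contextmanager
def broken_scope():
    global depth
    depth += 1
    yield          # an exception in the with body re-raises here,
    depth -= 1     # so this cleanup never runs on the error path

@contextlib.contextmanager
def fixed_scope():
    global depth
    depth += 1
    try:
        yield
    finally:
        depth -= 1  # runs on both the success and the error path

try:
    with broken_scope():
        raise RuntimeError("boom")
except RuntimeError:
    pass
assert depth == 1  # leaked: the decrement was skipped

depth = 0
try:
    with fixed_scope():
        raise RuntimeError("boom")
except RuntimeError:
    pass
assert depth == 0  # restored despite the exception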

15 files changed: +91 -59 lines changed


benchmarks/tensorexpr/benchmark.py

+9 -8

@@ -228,14 +228,15 @@ def cuda_pointwise_context(loop_levels, block_count, block_size):
     old_block_size = torch._C._jit_get_te_cuda_pointwise_block_size()
     torch._C._jit_set_te_cuda_pointwise_block_size(block_size)
 
-    yield
-
-    if loop_levels:
-        torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
-    if block_count:
-        torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
-    if block_size:
-        torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)
+    try:
+        yield
+    finally:
+        if loop_levels:
+            torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
+        if block_count:
+            torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
+        if block_size:
+            torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)
 
 # Auxiliary class to facilitate dynamic input shape
 class DynamicShape:

caffe2/python/net_printer.py

+12 -8

@@ -65,8 +65,10 @@ def set_workspace(self, node=None, ws=None, do_copy=False):
         if do_copy:
             ws = copy(ws)
         self.workspace_ctx.append(ws)
-        yield ws
-        del self.workspace_ctx[-1]
+        try:
+            yield ws
+        finally:
+            del self.workspace_ctx[-1]
 
     def define_blob(self, blob):
         self.workspace[blob] += 1
@@ -166,12 +168,14 @@ def context(self, text):
             self.add('with %s:' % text)
             self._indent += 4
         self._lines_in_context.append(0)
-        yield
-        if text is not None:
-            if self._lines_in_context[-1] == 0:
-                self.add('pass')
-            self._indent -= 4
-        del self._lines_in_context[-1]
+        try:
+            yield
+        finally:
+            if text is not None:
+                if self._lines_in_context[-1] == 0:
+                    self.add('pass')
+                self._indent -= 4
+            del self._lines_in_context[-1]
 
     def add(self, text):
         self._lines_in_context[-1] += 1

test/test_spectral_ops.py

+4 -2

@@ -811,8 +811,10 @@ def plan_cache_max_size(device, n):
                 plan_cache = torch.backends.cuda.cufft_plan_cache[device]
             original = plan_cache.max_size
             plan_cache.max_size = n
-            yield
-            plan_cache.max_size = original
+            try:
+                yield
+            finally:
+                plan_cache.max_size = original
 
         with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
             self._test_fft_ifft_rfft_irfft(devices[0], dtype)

torch/_dynamo/utils.py

-1

@@ -1012,7 +1012,6 @@ def disable_cache_limit():
     try:
         yield
     finally:
-        pass
         config.cache_size_limit = prior
 
 
torch/_functorch/aot_autograd.py

+6 -4

@@ -1228,10 +1228,12 @@ def track_graph_compiling(aot_config, graph_name):
     global graph_being_compiled
     # TODO: Don't shove the aot_id in here; set it in the context
     graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
-    yield
-    global nth_graph
-    nth_graph += 1
-    graph_being_compiled = []
+    try:
+        yield
+    finally:
+        global nth_graph
+        nth_graph += 1
+        graph_being_compiled = []
 
 
 def make_boxed_func(f):

torch/_inductor/codegen/common.py

+11 -7

@@ -560,8 +560,10 @@ def __init__(self, args=None):
     def set_current_node(self, node):
         prior = self.current_node
         self.current_node = node
-        yield
-        self.current_node = prior
+        try:
+            yield
+        finally:
+            self.current_node = prior
 
     @contextlib.contextmanager
     def swap_buffers(self, lb, cb=None, sb=None):
@@ -575,11 +577,13 @@ def swap_buffers(self, lb, cb=None, sb=None):
         self.compute = cb
         self.stores = sb
         self.cse = cse.clone()
-        yield
-        self.loads = loads
-        self.compute = compute
-        self.stores = stores
-        self.cse = cse
+        try:
+            yield
+        finally:
+            self.loads = loads
+            self.compute = compute
+            self.stores = stores
+            self.cse = cse
 
     def load(self, name: str, index: sympy.Expr):
         raise NotImplementedError()

torch/_inductor/codegen/triton.py

+13 -9

@@ -696,11 +696,13 @@ def ctx():
                 # and write out a reduction loop
                 self.codegen_body()
             self.inside_reduction = False
-            yield
-            if not self.persistent_reduction:
-                # flush out any code before opening the next loop
-                self.codegen_body()
-            self.inside_reduction = True
+            try:
+                yield
+                if not self.persistent_reduction:
+                    # flush out any code before opening the next loop
+                    self.codegen_body()
+            finally:
+                self.inside_reduction = True
 
         return ctx()
 
@@ -957,10 +959,12 @@ def mask_loads(self, mask):
             mask = self.cse.generate(self.compute, f"{mask} & {prior}")
 
         self._load_mask = mask
-        with self.swap_buffers(self.compute, self.compute):
-            # TODO(jansel): do we need a reshape here?
-            yield mask
-        self._load_mask = prior
+        try:
+            with self.swap_buffers(self.compute, self.compute):
+                # TODO(jansel): do we need a reshape here?
+                yield mask
+        finally:
+            self._load_mask = prior
 
     def load(self, name: str, index: sympy.Expr):
         var = self.args.input(name)

torch/_inductor/triton_ops/autotune.py

+5 -3

@@ -228,7 +228,7 @@ def end_graph():
     cur_file = inspect.stack()[1].filename
     print(f"SUMMARY ({cur_file})")
     print(
-        f"{overall_time:.2f}ms\t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
+        f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
     )
     print()
 
@@ -250,10 +250,12 @@ def run(self, *args, grid, stream):
         num_gb = get_num_bytes(*args) / 1e9
         gb_per_s = num_gb / (ms / 1e3)
 
-        collected_calls.append((kernel_name, ms, num_gb, gb_per_s))
+        collected_calls.append((ms, num_gb, gb_per_s, kernel_name)),
         import colorama
 
-        info_str = f"{kernel_name}\t {ms:.3f}ms\t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s"
+        info_str = (
+            f"{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s \t {kernel_name}"
+        )
         if ms > 0.012 and gb_per_s < 650:
             print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
         else:
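Note that these two hunks fix output formatting rather than a missing finally: the per-kernel line now leads with the numeric columns and moves the kernel name to the end, matching the column order of the SUMMARY line. A sketch of the resulting layout, with made-up values:

# Made-up values, only to illustrate the column order of the new format strings.
ms, num_gb, gb_per_s = 0.052, 0.134, 2576.92
kernel_name = "triton_example_kernel"  # hypothetical kernel name
overall_time, overall_gb = 1.20, 3.10

print(f"{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s \t {kernel_name}")
print(f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s")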

torch/_inductor/utils.py

+4 -2

@@ -444,8 +444,10 @@ def indent(self, offset=1):
         @contextlib.contextmanager
         def ctx():
             self._indent += offset
-            yield
-            self._indent -= offset
+            try:
+                yield
+            finally:
+                self._indent -= offset
 
         return ctx()
 

torch/_jit_internal.py

+4 -2

@@ -1243,8 +1243,10 @@ def _create_named_tuple(
 def _disable_emit_hooks():
     hooks = torch._C._jit_get_emit_hooks()
     torch._C._jit_set_emit_hooks(None, None)
-    yield
-    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
+    try:
+        yield
+    finally:
+        torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
 
 
 def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811

torch/cuda/nvtx.py

+4 -2

@@ -84,5 +84,7 @@ def range(msg, *args, **kwargs):
         msg (str): message to associate with the range
     """
     range_push(msg.format(*args, **kwargs))
-    yield
-    range_pop()
+    try:
+        yield
+    finally:
+        range_pop()
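With the finally in place, the NVTX range is popped even when the with body raises, keeping the push/pop stack balanced in profiler traces. A usage sketch (torch.cuda.nvtx.range is the context manager patched above; the failure is simulated and a CUDA build is assumed):

import torch

try:
    with torch.cuda.nvtx.range("forward"):
        raise RuntimeError("simulated failure inside the profiled region")
except RuntimeError:
    pass
# Before the fix the exception skipped range_pop(), leaving "forward" open
# for the rest of the trace; now the range is closed on both paths.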

torch/distributed/elastic/multiprocessing/redirects.py

+4 -2

@@ -93,8 +93,10 @@ def _redirect(dst):
 
     with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
         _redirect(dst)
-        yield
-        _redirect(orig_std)
+        try:
+            yield
+        finally:
+            _redirect(orig_std)
 
 
 redirect_stdout = partial(redirect, "stdout")
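A usage sketch of the repaired helper (redirect_stdout is the partial defined above, so it takes just the target file; the module works on raw file descriptors and is Linux-oriented, so treat this as an assumption-laden example):

from torch.distributed.elastic.multiprocessing.redirects import redirect_stdout

try:
    with redirect_stdout("/tmp/out.log"):  # log path is illustrative
        print("captured in the log file")
        raise RuntimeError("boom")
except RuntimeError:
    pass
print("stdout restored despite the exception")  # fd is back thanks to the finally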

torch/profiler/itt.py

+4 -2

@@ -69,5 +69,7 @@ def range(msg, *args, **kwargs):
         msg (str): message to associate with the range
     """
     range_push(msg.format(*args, **kwargs))
-    yield
-    range_pop()
+    try:
+        yield
+    finally:
+        range_pop()

torch/serialization.py

+4 -2

@@ -56,8 +56,10 @@ class SourceChangeWarning(Warning):
 @contextmanager
 def mkdtemp():
     path = tempfile.mkdtemp()
-    yield path
-    shutil.rmtree(path)
+    try:
+        yield path
+    finally:
+        shutil.rmtree(path)
 
 
 _package_registry = []
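Here the same fix guards a filesystem resource: without the finally, an exception raised while the temporary directory is in use would leak it on disk. A self-contained sketch mirroring the patched helper:

import os
import shutil
import tempfile
from contextlib import contextmanager

@contextmanager
def mkdtemp():  # same shape as the helper patched above
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)

try:
    with mkdtemp() as path:
        tmp = path
        raise OSError("failure while using the temp dir")
except OSError:
    pass
assert not os.path.exists(tmp)  # directory removed despite the exception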

torch/testing/_internal/common_distributed.py

+7 -5

@@ -1189,11 +1189,13 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True):
         c10d.init_process_group("nccl", rank=rank, world_size=world_size)
     torch._dynamo.reset()
     torch._dynamo.utils.counters.clear()
-    yield
-    torch._dynamo.reset()
-    torch._dynamo.utils.counters.clear()
-    if init_pg:
-        c10d.destroy_process_group()
+    try:
+        yield
+    finally:
+        torch._dynamo.reset()
+        torch._dynamo.utils.counters.clear()
+        if init_pg:
+            c10d.destroy_process_group()
 
 
 class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
