
Commit b7b2178

Skylion007 authored and pytorchmergebot committed
[BE]: Remove useless lambdas (pytorch#113602)
Applies PLW0108, which removes useless lambdas in Python (lambdas that merely forward their arguments to another callable). The rule is in preview, so it is not ready to be enabled by default just yet; these changes are the autofixes from the rule.

Pull Request resolved: pytorch#113602
Approved by: https://github.com/albanD
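For context, PLW0108 is Ruff's preview "unnecessary lambda" rule: it flags a lambda that does nothing but forward its arguments to another callable, and its autofix replaces the lambda with that callable. A minimal sketch of the pattern, using illustrative names rather than code from this commit:

    # Before: the lambda merely forwards its argument.
    names = sorted(["bb", "a", "ccc"], key=lambda s: len(s))

    # After the PLW0108 autofix: pass the callable directly.
    names = sorted(["bb", "a", "ccc"], key=len)

One reason such a rule can stay in preview is that the rewrite is not always behavior-preserving: a lambda resolves the wrapped name at call time, while the direct reference binds it once. A hedged illustration (Config and cfg are hypothetical, not from this commit):

    class Config:
        def reload(self):
            print("reloaded")

    cfg = Config()

    late = lambda: cfg.reload()  # looks up cfg.reload on every call
    early = cfg.reload           # captures the bound method once, now

    # If cfg.reload is replaced later (e.g. monkeypatched in a test),
    # late() sees the new attribute while early() still calls the original.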
1 parent 2a8a742 · commit b7b2178

35 files changed, +72 −78 lines changed

benchmarks/cpp/tensorexpr/bench_ops.py (+2 −2)

@@ -51,7 +51,7 @@ def hardswish(x):
 
 for op in unary_ops:
     x = torch.rand((1024, 1024))
-    traced = torch.jit.trace(lambda x: op(x), (x))
+    traced = torch.jit.trace(op, (x))
 
     # Warmup.
     warmup_iters = 8
@@ -88,7 +88,7 @@ def test_batch_norm():
     x = torch.rand((n, c, h, w))
     y = torch.rand(c)
     z = torch.rand(c)
-    traced = torch.jit.trace(lambda x, y, z: op(x, y, z), (x, y, z))
+    traced = torch.jit.trace(op, (x, y, z))
 
     # Warmup.
     warmup_iters = 8

benchmarks/tensorexpr/broadcast.py (+8 −8)

@@ -268,18 +268,18 @@ def register_broadcast_ops():
         ["div", lambda a, b: a / (b + 1e-4)],
         [
             "pow",
-            lambda a, b: torch.pow(a, b),
-            lambda a, b: np.power(a, b),
+            torch.pow,
+            np.power,
         ], # no fuson triggered
-        ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)],
-        ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)],
+        ["max", torch.max, np.maximum],
+        ["min", torch.min, np.minimum],
     ]
 
     unary_op_list = [
-        ["erf", lambda x: torch.erf(x), lambda x: np.erf(x)],
-        ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)],
-        ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)],
-        ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)],
+        ["erf", torch.erf, np.erf],
+        ["exp", torch.exp, np.exp],
+        ["sin", torch.sin, np.sin],
+        ["cos", torch.cos, np.cos],
     ]
 
     for split_input, binary_op in itertools.product([True, False], binary_op_list):

benchmarks/tensorexpr/elementwise.py (+9 −9)

@@ -122,19 +122,19 @@ def register_element_ops():
         ["div", lambda a, b: a / (b + 1e-4)],
         [
             "pow",
-            lambda a, b: torch.pow(a, b),
-            lambda a, b: np.power(a, b),
+            torch.pow,
+            np.power,
         ], # no fuson triggered
-        ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)],
-        ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)],
+        ["max", torch.max, np.maximum],
+        ["min", torch.min, np.minimum],
     ]
 
     unary_op_list = [
-        ["erf", lambda x: torch.erf(x), lambda x: scipy.special.erf(x)],
-        ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)],
-        ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)],
-        ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)],
-        ["rand_like", lambda x: torch.rand_like(x), lambda x: np.random.rand(*x.shape)],
+        ["erf", torch.erf, scipy.special.erf],
+        ["exp", torch.exp, np.exp],
+        ["sin", torch.sin, np.sin],
+        ["cos", torch.cos, np.cos],
+        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
     ]
 
     for split_input, binary_op in itertools.product([True, False], binary_op_list):

caffe2/python/memonger.py (+1 −1)

@@ -648,7 +648,7 @@ def _find_best(ranges, init_assignment, prev_best_assignment, counter):
         # Try to put 'find_range' in a new assignment
         best_candidates.append(prev_best_assignment + [[find_range]])
 
-        ret = min(best_candidates, key=lambda x: get_memory_usage(x))
+        ret = min(best_candidates, key=get_memory_usage)
         return ret
 
     if not counter:

caffe2/python/net_drawer.py (+1 −1)

@@ -374,7 +374,7 @@ def main():
         content = fid.read()
     graphs = utils.GetContentFromProtoString(
         content, {
-            caffe2_pb2.PlanDef: lambda x: GetOperatorMapForPlan(x),
+            caffe2_pb2.PlanDef: GetOperatorMapForPlan,
             caffe2_pb2.NetDef: lambda x: {x.name: x.op},
         }
     )

caffe2/python/parallel_workers_test.py (+2 −2)

@@ -60,7 +60,7 @@ def testParallelWorkers(self):
         workspace.ResetWorkspace()
 
         queue = create_queue()
-        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
+        dummy_worker = create_worker(queue, str)
         worker_coordinator = parallel_workers.init_workers(dummy_worker)
         worker_coordinator.start()
 
@@ -102,7 +102,7 @@ def testParallelWorkersShutdownFun(self):
         workspace.ResetWorkspace()
 
         queue = create_queue()
-        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
+        dummy_worker = create_worker(queue, str)
         workspace.FeedBlob('data', 'not shutdown')
 
         def shutdown_fun():

tools/autograd/gen_python_functions.py (+3 −3)

@@ -405,7 +405,7 @@ def create_python_bindings(
 
     grouped = group_filter_overloads(pairs, pred)
 
-    for name in sorted(grouped.keys(), key=lambda x: str(x)):
+    for name in sorted(grouped.keys(), key=str):
         overloads = grouped[name]
         py_methods.append(
             method_impl(name, module, overloads, method=method, symint=symint)
@@ -443,7 +443,7 @@ def create_python_return_type_bindings(
 
     grouped = group_filter_overloads(pairs, pred)
 
-    for name in sorted(grouped.keys(), key=lambda x: str(x)):
+    for name in sorted(grouped.keys(), key=str):
         overloads = grouped[name]
         definitions, registrations = generate_return_type_definition_and_registrations(
             overloads
@@ -481,7 +481,7 @@ def create_python_return_type_bindings_header(
 
     grouped = group_filter_overloads(pairs, pred)
 
-    for name in sorted(grouped.keys(), key=lambda x: str(x)):
+    for name in sorted(grouped.keys(), key=str):
         overloads = grouped[name]
         declarations = generate_return_type_declarations(overloads)
         py_return_types_declarations.append(

torch/_custom_op/autograd.py (+1 −1)

@@ -108,7 +108,7 @@ def forward(ctx, *flat_args):
 
         # We use the info about args to give better error messages in backward
         args_info = namedtuple_args(
-            schema, pytree.tree_map(lambda arg: type(arg), args))
+            schema, pytree.tree_map(type, args))
 
         save_for_backward_fn_inputs = namedtuple_args(schema, args)
         to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

torch/_dynamo/guards.py (+1 −1)

@@ -114,7 +114,7 @@ def uninteresting_files():
     "___tuple_iterator_getitem": tuple_iterator_getitem,
     "__math_isnan": math.isnan,
     "inf": float("inf"),
-    "__load_module": lambda name: importlib.import_module(name),
+    "__load_module": importlib.import_module,
     "utils_device": torch.utils._device,
     "device": torch.device,
     "___from_numpy":

torch/_dynamo/utils.py (+1 −1)

@@ -1325,7 +1325,7 @@ class CompileProfiler:
     def __init__(self):
         self.frame_count = 0
         self.op_count = 0
-        self.backend_ctx_ctor = lambda: disable_cache_limit()
+        self.backend_ctx_ctor = disable_cache_limit
 
     def __call__(self, gm: torch.fx.GraphModule, example_inputs):
         self.frame_count += 1

torch/_export/db/case.py (+1 −1)

@@ -79,7 +79,7 @@ class ExportCase:
     name: str
     extra_inputs: Optional[InputsType] = None  # For testing graph generalization.
     # Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
-    tags: Set[str] = field(default_factory=lambda: set())
+    tags: Set[str] = field(default_factory=set)
     support_level: SupportLevel = SupportLevel.SUPPORTED
     dynamic_shapes: Optional[Dict[str, Any]] = None
 
torch/_export/unflatten.py (+1 −1)

@@ -464,7 +464,7 @@ def finalize_outputs(self):
 
     def copy_node(self, node):
        self.print("copying", node.format_node())
-       self.node_map[node] = self.graph.node_copy(node, lambda n: self.remap_input(n))
+       self.node_map[node] = self.graph.node_copy(node, self.remap_input)
        self.seen_nodes[node.name] = node
 
     def run_outer(self):

torch/_functorch/partitioners.py (+1 −1)

@@ -886,7 +886,7 @@ def get_node_weight(node) -> int:
     node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
     saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
     # save_for_backward on tensors and stashes symints in autograd .ctx
-    saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
+    saved_sym_nodes = list(filter(is_sym_node, saved_values))
     saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
     # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
     fw_module, bw_module = _extract_fwd_bwd_modules(

torch/_inductor/autotune_process.py (+1 −1)

@@ -243,7 +243,7 @@ def initialize(self) -> None:
             EXIT_HANDLER_REGISTERED = True
             import atexit
 
-            atexit.register(lambda: self.terminate())
+            atexit.register(self.terminate)
 
     def get_device_list(self) -> List[Optional[int]]:
         """

torch/_inductor/codegen/cpp.py (+1 −3)

@@ -2720,9 +2720,7 @@ def select_tiling_indices():
                 and contig_vars_sorted[-1] == len(self.itervars) - 1
             ):
                 return contig_vars_sorted
-            return sorted(contig_vars_sorted, key=lambda i: contig_vars_list.count(i))[
-                -1:
-            ]
+            return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:]
 
         def select_tiling(dtype: torch.dtype = torch.float):
             # TODO(jgong5): support alternative tiling factors and data types

torch/_inductor/graph.py (+1 −1)

@@ -715,7 +715,7 @@ def set_current_node(self, node: torch.fx.Node):
 
     def run_node(self, n: torch.fx.Node):
         def debug(msg):
-            log.debug("lowering %s %s", LazyString(lambda: n.format_node()), msg)
+            log.debug("lowering %s %s", LazyString(n.format_node), msg)
 
         origins = {n}
         if n.op == "call_function":

torch/_inductor/inductor_prims.py (+1 −1)

@@ -71,7 +71,7 @@ def eager_force_stride(input_tensor: Tensor, stride) -> Tensor:
 )
 force_stride_order = make_prim(
     "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
-    lambda input_tensor, stride: eager_force_stride(input_tensor, stride),
+    eager_force_stride,
     doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
 )
 masked_scatter_with_index = make_prim(

torch/_inductor/kernel/mm.py (+1 −1)

@@ -300,6 +300,6 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
             layout=layout,
             **dict(mm_options(config, k, layout), **{"ACC_TYPE": "tl.int32"}),
             suffix_args=1,
-            epilogue_fn=lambda acc, mat3: V.ops.mul(acc, mat3),
+            epilogue_fn=V.ops.mul,
         )
     return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)

torch/_inductor/lowering.py (+3 −3)

@@ -1968,21 +1968,21 @@ def inner_fn(index):
 
 def require_dense(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
-        ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs)
+        ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs)
     )
     return args, kwargs
 
 
 def require_contiguous(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
-        ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs)
+        ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs)
     )
     return args, kwargs
 
 
 def require_channels_last(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
-        ir.IRNode, lambda t: ir.ExternKernel.require_channels_last(t), (args, kwargs)
+        ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
     )
     return args, kwargs
 
torch/_inductor/triton_heuristics.py (+1 −1)

@@ -608,7 +608,7 @@ def __init__(self, *args, regex_filter="", **kwargs):
 
     def run(self, *args, grid, stream):
         possible_names = _find_names(self)
-        kernel_name = f"{max(possible_names, key=lambda x: len(x))}"
+        kernel_name = f"{max(possible_names, key=len)}"
         if not re.match(self.regex_filter, kernel_name):
             return
         super().run(*args, grid=grid, stream=stream)

torch/_prims/__init__.py (+1 −1)

@@ -1231,7 +1231,7 @@ def _greater_than_reduce(acc, x):
 
         return x
 
-    reduce(lambda acc, x: _greater_than_reduce(acc, x), broadcast_dimensions, -1)
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
 
    # shape must be broadcastable to
    for idx, new_idx in enumerate(broadcast_dimensions):

torch/distributed/_spmd/data_parallel.py (+1 −1)

@@ -120,7 +120,7 @@ def gradients_tagging(params: Dict[str, torch.Tensor]):
     tagging_hooks = []
     try:
         for p in params.values():
-            h = p.register_hook(lambda grad: torch.ops._spmd.tag_grad(grad))
+            h = p.register_hook(torch.ops._spmd.tag_grad)
             tagging_hooks.append(h)
         yield
     finally:

torch/distributed/checkpoint/_fsspec_filesystem.py (+2 −2)

@@ -268,12 +268,12 @@ def _write_files_from_queue(
 
     if torch.cuda.is_available() and inflight_threshhold > 0:
         loader = _OverlappingCpuLoader(
-            lambda x: planner.resolve_data(x),
+            planner.resolve_data,
             inflight_threshhold=inflight_threshhold,
         )
     else:
         loader = _SerialCpuLoader(
-            lambda x: planner.resolve_data(x),
+            planner.resolve_data,
         )
 
     tensor_w = [

torch/distributed/checkpoint/filesystem.py (+2 −2)

@@ -265,12 +265,12 @@ def _write_files_from_queue(
 
     if torch.cuda.is_available() and inflight_threshhold > 0:
         loader = _OverlappingCpuLoader(
-            lambda x: planner.resolve_data(x),
+            planner.resolve_data,
            inflight_threshhold=inflight_threshhold,
        )
    else:
        loader = _SerialCpuLoader(
-            lambda x: planner.resolve_data(x),
+            planner.resolve_data,
        )
 
    tensor_w = [

torch/distributed/fsdp/_init_utils.py (+2 −6)

@@ -345,18 +345,14 @@ def _check_ignored_states(
     all_modules = all(isinstance(state, nn.Module) for state in ignored_states)
     if not all_params and not all_modules:
         # Sort for consistent ordering for unit test regex matching
-        sorted_types = sorted(
-            {type(state) for state in ignored_states}, key=lambda x: repr(x)
-        )
+        sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
         raise ValueError(
             "ignored_states expects all nn.Parameter or all nn.Module list "
             f"elements but got types {sorted_types}"
         )
     else:
         if not all(isinstance(state, nn.Module) for state in ignored_states):
-            sorted_types = sorted(
-                {type(state) for state in ignored_states}, key=lambda x: repr(x)
-            )
+            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
             raise ValueError(
                 "ignored_modules expects nn.Module list elements but got "
                 f"types {sorted_types}"

torch/distributed/fsdp/fully_sharded_data_parallel.py (+1 −1)

@@ -1185,7 +1185,7 @@ def clip_grad_norm_(
         ) # warn since this is generally unexpected
         return total_norm
     total_norm_dtype = functools.reduce(
-        lambda dtype1, dtype2: torch.promote_types(dtype1, dtype2),
+        torch.promote_types,
         [grad.dtype for grad in grads],
     )
     return total_norm.to(total_norm_dtype)

torch/distributed/rpc/rref_proxy.py (+1 −1)

@@ -59,7 +59,7 @@ def _complete_op(fut):
         except BaseException as ex:
             result.set_exception(ex)
 
-    rref_fut.then(lambda fut: _wrap_rref_type_cont(fut))
+    rref_fut.then(_wrap_rref_type_cont)
     return result
 
 # This class manages proxied RPC API calls for RRefs. It is entirely used from

torch/fx/experimental/accelerator_partitioner.py (+2 −2)

@@ -583,8 +583,8 @@ def dump_dag(self, module_with_submodules: GraphModule) -> DAG:
             if node.target == operator.__getitem__:
                 continue
             input_nodes: Dict[Node, None] = {}
-            map_arg(node.args, lambda n: input_nodes.setdefault(n))
-            map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
+            map_arg(node.args, input_nodes.setdefault)
+            map_arg(node.kwargs, input_nodes.setdefault)
             # When a node has two or more output nodes,
             # it outputs its result to 'getitem' nodes.
             # Those 'getitem' nodes are the output node for this node.
