Skip to content

Commit 07caea5

Browse files
janselpytorchmergebot
authored andcommitted
[dynamo] Refactor COMPARE_OP and comparison builtins (pytorch#122043)
This removes the duplicate handling of comparison ops between symbolic_convert and bultin and refactors the handling to use the binop infrastructure. This change regresses overheads a bit, but this is fixed in the next PR. New test skips are variants of `type(e) is np.ndarray` previously falling back to eager. Pull Request resolved: pytorch#122043 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#122039
1 parent 769ff86 commit 07caea5

13 files changed

+305
-243
lines changed

benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
8686

8787

8888

89-
detectron2_fcos_r_50_fpn,pass,35
89+
detectron2_fcos_r_50_fpn,pass,94
9090

9191

9292

benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv

+11-11
Original file line numberDiff line numberDiff line change
@@ -54,47 +54,47 @@ densenet121,pass,0
5454

5555

5656

57-
detectron2_fasterrcnn_r_101_c4,pass,51
57+
detectron2_fasterrcnn_r_101_c4,pass,164
5858

5959

6060

61-
detectron2_fasterrcnn_r_101_dc5,pass,51
61+
detectron2_fasterrcnn_r_101_dc5,pass,163
6262

6363

6464

65-
detectron2_fasterrcnn_r_101_fpn,pass,55
65+
detectron2_fasterrcnn_r_101_fpn,pass,172
6666

6767

6868

69-
detectron2_fasterrcnn_r_50_c4,pass,51
69+
detectron2_fasterrcnn_r_50_c4,pass,113
7070

7171

7272

73-
detectron2_fasterrcnn_r_50_dc5,pass,51
73+
detectron2_fasterrcnn_r_50_dc5,pass,112
7474

7575

7676

77-
detectron2_fasterrcnn_r_50_fpn,pass,55
77+
detectron2_fasterrcnn_r_50_fpn,pass,121
7878

7979

8080

81-
detectron2_fcos_r_50_fpn,pass,38
81+
detectron2_fcos_r_50_fpn,pass,97
8282

8383

8484

85-
detectron2_maskrcnn_r_101_c4,fail_accuracy,66
85+
detectron2_maskrcnn_r_101_c4,pass,182
8686

8787

8888

89-
detectron2_maskrcnn_r_101_fpn,pass,73
89+
detectron2_maskrcnn_r_101_fpn,pass,192
9090

9191

9292

93-
detectron2_maskrcnn_r_50_c4,pass,66
93+
detectron2_maskrcnn_r_50_c4,pass,131
9494

9595

9696

97-
detectron2_maskrcnn_r_50_fpn,pass,73
97+
detectron2_maskrcnn_r_50_fpn,pass,141
9898

9999

100100

benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
8686

8787

8888

89-
detectron2_fcos_r_50_fpn,pass,35
89+
detectron2_fcos_r_50_fpn,pass,94
9090

9191

9292

benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ densenet121,pass,0
5454

5555

5656

57-
detectron2_fcos_r_50_fpn,pass,38
57+
detectron2_fcos_r_50_fpn,pass,97
5858

5959

6060

benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
8686

8787

8888

89-
detectron2_fcos_r_50_fpn,pass,36
89+
detectron2_fcos_r_50_fpn,pass,95
9090

9191

9292

benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
8686

8787

8888

89-
detectron2_fcos_r_50_fpn,pass,35
89+
detectron2_fcos_r_50_fpn,pass,94
9090

9191

9292

benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
8686

8787

8888

89-
detectron2_fcos_r_50_fpn,pass,36
89+
detectron2_fcos_r_50_fpn,pass,95
9090

9191

9292

test/dynamo/test_functions.py

+20
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,26 @@ def test_partials_udf_arg(x):
14591459
par_mul = functools.partial(udf_mul, torch.ones(10, 10))
14601460
return par_mul(x)
14611461

1462+
@make_test
1463+
def test_list_add_then_mutate(x):
1464+
my_list = [1, x]
1465+
y = x / 4.0
1466+
my_list = my_list + [x / 2.0, 4]
1467+
my_list.append(y)
1468+
return sum(my_list)
1469+
1470+
@make_test
1471+
def test_list_expand_lhs(x):
1472+
return sum(4 * [x])
1473+
1474+
@make_test
1475+
def test_in_not_in(x):
1476+
mylist = [1, 2, 3, 4, 5, x]
1477+
myotherlist = [1, 2, 3, 4, 5]
1478+
assert 3 in mylist
1479+
assert 6 not in myotherlist
1480+
return sum(mylist)
1481+
14621482
@make_test
14631483
def test_partials_udf_kwarg(x):
14641484
par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))

test/dynamo_skips/TestSqueeze.test_squeeze_type

Whitespace-only changes.

test/dynamo_skips/TestSubscripting.test_test_zero_rank

Whitespace-only changes.

torch/_dynamo/symbolic_convert.py

+13-50
Original file line numberDiff line numberDiff line change
@@ -99,24 +99,29 @@
9999
UnknownVariable,
100100
)
101101
from .variables.nn_module import NNModuleVariable
102-
from .variables.tensor import (
103-
supported_comparison_ops,
104-
supported_const_comparison_ops,
105-
SymNodeVariable,
106-
TensorVariable,
107-
)
102+
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
108103
from .variables.user_defined import (
109104
RemovableHandleVariable,
110105
UserDefinedClassVariable,
111106
UserDefinedObjectVariable,
112-
UserDefinedVariable,
113107
)
114108

115109
log = logging.getLogger(__name__)
116110
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
117111
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
118112
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
119113
tls = threading.local()
114+
compare_op_handlers: Dict[str, Any] = {
115+
k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
116+
}
117+
handle_contains = BuiltinVariable(operator.contains).call_function
118+
handle_not = BuiltinVariable(operator.not_).call_function
119+
compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
120+
tx, [*reversed(args)], {}
121+
)
122+
compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
123+
tx, [handle_contains(tx, [*reversed(args)], {})], {}
124+
)
120125

121126

122127
@dataclasses.dataclass
@@ -1200,49 +1205,7 @@ def FOR_ITER(self, inst):
12001205
unimplemented(f"FOR_ITER {typestr(it)}")
12011206

12021207
def COMPARE_OP(self, inst):
1203-
left, right = self.popn(2)
1204-
op = inst.argval
1205-
if op == "in" or op == "not in":
1206-
self.push(right.call_method(self, "__contains__", [left], {}))
1207-
if op == "not in":
1208-
self.UNARY_NOT(inst)
1209-
return
1210-
1211-
if right.is_python_constant():
1212-
if left.is_python_constant():
1213-
# constant fold
1214-
return self.push(
1215-
ConstantVariable(
1216-
supported_comparison_ops[op](
1217-
left.as_python_constant(), right.as_python_constant()
1218-
),
1219-
)
1220-
)
1221-
elif (
1222-
op in supported_const_comparison_ops
1223-
and right.as_python_constant() is None
1224-
and isinstance(
1225-
left,
1226-
(
1227-
TensorVariable,
1228-
SymNodeVariable,
1229-
NNModuleVariable,
1230-
BaseListVariable,
1231-
UserDefinedVariable,
1232-
BaseUserFunctionVariable,
1233-
ConstDictVariable,
1234-
),
1235-
)
1236-
):
1237-
# <non-None> is None
1238-
return self.push(
1239-
ConstantVariable(supported_const_comparison_ops[op](object(), None))
1240-
)
1241-
self.push(
1242-
BuiltinVariable(supported_comparison_ops[op]).call_function(
1243-
self, [left, right], {}
1244-
)
1245-
)
1208+
self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
12461209

12471210
def GET_ITER(self, inst):
12481211
self.call_function(BuiltinVariable(iter), [self.pop()], {})

0 commit comments

Comments
 (0)