Skip to content

Commit 769ff86

Browse files
janselpytorchmergebot
authored andcommitted
[dynamo] Optimize COMPARE_OP (pytorch#122039)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py` from 5.6 to 5.1s. Pull Request resolved: pytorch#122039 Approved by: https://github.com/Skylion007, https://github.com/anijain2305
1 parent e1706bb commit 769ff86

File tree

3 files changed

+56
-57
lines changed

3 files changed

+56
-57
lines changed

torch/_dynamo/symbolic_convert.py

+38-51
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@
100100
)
101101
from .variables.nn_module import NNModuleVariable
102102
from .variables.tensor import (
103+
supported_comparison_ops,
103104
supported_const_comparison_ops,
104-
supported_tensor_comparison_ops,
105105
SymNodeVariable,
106106
TensorVariable,
107107
)
@@ -880,8 +880,7 @@ def pop(self) -> VariableTracker:
880880
return self.stack.pop()
881881

882882
def popn(self, n: int) -> List[VariableTracker]:
883-
assert n >= 0
884-
return list(reversed([self.pop() for _ in range(n)]))
883+
return [*reversed([self.pop() for _ in range(n)])]
885884

886885
def LOAD_FAST(self, inst):
887886
name = inst.argval
@@ -1203,59 +1202,47 @@ def FOR_ITER(self, inst):
12031202
def COMPARE_OP(self, inst):
12041203
left, right = self.popn(2)
12051204
op = inst.argval
1206-
supported_any = dict(
1207-
itertools.chain(
1208-
supported_tensor_comparison_ops.items(),
1209-
supported_const_comparison_ops.items(),
1210-
)
1211-
)
1212-
if (
1213-
isinstance(
1214-
left,
1215-
(
1216-
TensorVariable,
1217-
SymNodeVariable,
1218-
NNModuleVariable,
1219-
BaseListVariable,
1220-
UserDefinedVariable,
1221-
BaseUserFunctionVariable,
1222-
ConstDictVariable,
1223-
),
1224-
)
1225-
and isinstance(right, ConstantVariable)
1226-
and right.value is None
1227-
and op in supported_const_comparison_ops
1228-
):
1229-
# <non-None> is None
1230-
self.push(
1231-
ConstantVariable.create(
1232-
supported_const_comparison_ops[op](object(), right.value)
1233-
)
1234-
)
1235-
1236-
elif (
1237-
left.is_python_constant()
1238-
and right.is_python_constant()
1239-
and op in supported_any
1240-
):
1241-
# constant fold
1242-
self.push(
1243-
ConstantVariable.create(
1244-
supported_any[op](
1245-
left.as_python_constant(), right.as_python_constant()
1246-
),
1247-
)
1248-
)
1249-
elif op in ("in", "not in"):
1205+
if op == "in" or op == "not in":
12501206
self.push(right.call_method(self, "__contains__", [left], {}))
12511207
if op == "not in":
12521208
self.UNARY_NOT(inst)
1253-
else:
1254-
self.push(
1255-
BuiltinVariable(supported_any[op]).call_function(
1256-
self, [left, right], {}
1209+
return
1210+
1211+
if right.is_python_constant():
1212+
if left.is_python_constant():
1213+
# constant fold
1214+
return self.push(
1215+
ConstantVariable(
1216+
supported_comparison_ops[op](
1217+
left.as_python_constant(), right.as_python_constant()
1218+
),
1219+
)
12571220
)
1221+
elif (
1222+
op in supported_const_comparison_ops
1223+
and right.as_python_constant() is None
1224+
and isinstance(
1225+
left,
1226+
(
1227+
TensorVariable,
1228+
SymNodeVariable,
1229+
NNModuleVariable,
1230+
BaseListVariable,
1231+
UserDefinedVariable,
1232+
BaseUserFunctionVariable,
1233+
ConstDictVariable,
1234+
),
1235+
)
1236+
):
1237+
# <non-None> is None
1238+
return self.push(
1239+
ConstantVariable(supported_const_comparison_ops[op](object(), None))
1240+
)
1241+
self.push(
1242+
BuiltinVariable(supported_comparison_ops[op]).call_function(
1243+
self, [left, right], {}
12581244
)
1245+
)
12591246

12601247
def GET_ITER(self, inst):
12611248
self.call_function(BuiltinVariable(iter), [self.pop()], {})

torch/_dynamo/variables/builtin.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1526,8 +1526,8 @@ def _comparison(self, tx, left, right):
15261526
)
15271527
from .lists import SizeVariable
15281528
from .tensor import (
1529-
supported_const_comparison_ops,
1530-
supported_tensor_comparison_ops,
1529+
supported_const_comparison_op_values,
1530+
supported_tensor_comparison_op_values,
15311531
)
15321532

15331533
op = self.fn
@@ -1540,7 +1540,7 @@ def _unimplemented():
15401540
isinstance(x, (NNModuleVariable, ConstantVariable))
15411541
for x in [left, right]
15421542
)
1543-
and op in supported_const_comparison_ops.values()
1543+
and op in supported_const_comparison_op_values
15441544
):
15451545
left = (
15461546
tx.output.get_submodule(left.module_key)
@@ -1555,7 +1555,7 @@ def _unimplemented():
15551555
return ConstantVariable.create(op(left, right))
15561556

15571557
if isinstance(left, UserFunctionVariable):
1558-
if op not in supported_const_comparison_ops.values():
1558+
if op not in supported_const_comparison_op_values:
15591559
_unimplemented()
15601560
if not isinstance(right, UserFunctionVariable):
15611561
_unimplemented()
@@ -1594,7 +1594,7 @@ def _unimplemented():
15941594
else:
15951595
return ConstantVariable.create(not is_result)
15961596

1597-
if op not in supported_tensor_comparison_ops.values():
1597+
if op not in supported_tensor_comparison_op_values:
15981598
_unimplemented()
15991599
if (
16001600
isinstance(left, TensorVariable)
@@ -1618,7 +1618,7 @@ def _unimplemented():
16181618
)
16191619

16201620
if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
1621-
if op not in supported_tensor_comparison_ops.values():
1621+
if op not in supported_tensor_comparison_op_values:
16221622
_unimplemented()
16231623

16241624
proxy = tx.output.create_proxy(

torch/_dynamo/variables/tensor.py

+12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from .constant import ConstantVariable
5858
from .lists import SizeVariable
5959

60+
# Ops that allow tensor <op> tensor
6061
supported_tensor_comparison_ops = {
6162
">": operator.gt,
6263
"<": operator.lt,
@@ -65,12 +66,23 @@
6566
"==": operator.eq,
6667
"!=": operator.ne,
6768
}
69+
# Ops that allow tensor <op> None
6870
supported_const_comparison_ops = {
6971
"is": operator.is_,
7072
"is not": operator.is_not,
7173
"==": operator.eq,
7274
"!=": operator.ne,
7375
}
76+
supported_comparison_ops = {
77+
**supported_tensor_comparison_ops,
78+
**supported_const_comparison_ops,
79+
}
80+
supported_tensor_comparison_op_values = dict.fromkeys(
81+
supported_tensor_comparison_ops.values()
82+
)
83+
supported_const_comparison_op_values = dict.fromkeys(
84+
supported_const_comparison_ops.values()
85+
)
7486

7587

7688
class TensorVariable(VariableTracker):

0 commit comments

Comments
 (0)