|
99 | 99 | UnknownVariable,
|
100 | 100 | )
|
101 | 101 | from .variables.nn_module import NNModuleVariable
|
102 |
| -from .variables.tensor import ( |
103 |
| - supported_comparison_ops, |
104 |
| - supported_const_comparison_ops, |
105 |
| - SymNodeVariable, |
106 |
| - TensorVariable, |
107 |
| -) |
| 102 | +from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable |
108 | 103 | from .variables.user_defined import (
|
109 | 104 | RemovableHandleVariable,
|
110 | 105 | UserDefinedClassVariable,
|
111 | 106 | UserDefinedObjectVariable,
|
112 |
| - UserDefinedVariable, |
113 | 107 | )
|
114 | 108 |
|
115 | 109 | log = logging.getLogger(__name__)
|
116 | 110 | graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
117 | 111 | trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
|
118 | 112 | trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
|
119 | 113 | tls = threading.local()
|
| 114 | +compare_op_handlers: Dict[str, Any] = { |
| 115 | + k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items() |
| 116 | +} |
| 117 | +handle_contains = BuiltinVariable(operator.contains).call_function |
| 118 | +handle_not = BuiltinVariable(operator.not_).call_function |
| 119 | +compare_op_handlers["in"] = lambda tx, args, _: handle_contains( |
| 120 | + tx, [*reversed(args)], {} |
| 121 | +) |
| 122 | +compare_op_handlers["not in"] = lambda tx, args, _: handle_not( |
| 123 | + tx, [handle_contains(tx, [*reversed(args)], {})], {} |
| 124 | +) |
120 | 125 |
|
121 | 126 |
|
122 | 127 | @dataclasses.dataclass
|
@@ -1200,49 +1205,7 @@ def FOR_ITER(self, inst):
|
1200 | 1205 | unimplemented(f"FOR_ITER {typestr(it)}")
|
1201 | 1206 |
|
1202 | 1207 | def COMPARE_OP(self, inst):
|
1203 |
| - left, right = self.popn(2) |
1204 |
| - op = inst.argval |
1205 |
| - if op == "in" or op == "not in": |
1206 |
| - self.push(right.call_method(self, "__contains__", [left], {})) |
1207 |
| - if op == "not in": |
1208 |
| - self.UNARY_NOT(inst) |
1209 |
| - return |
1210 |
| - |
1211 |
| - if right.is_python_constant(): |
1212 |
| - if left.is_python_constant(): |
1213 |
| - # constant fold |
1214 |
| - return self.push( |
1215 |
| - ConstantVariable( |
1216 |
| - supported_comparison_ops[op]( |
1217 |
| - left.as_python_constant(), right.as_python_constant() |
1218 |
| - ), |
1219 |
| - ) |
1220 |
| - ) |
1221 |
| - elif ( |
1222 |
| - op in supported_const_comparison_ops |
1223 |
| - and right.as_python_constant() is None |
1224 |
| - and isinstance( |
1225 |
| - left, |
1226 |
| - ( |
1227 |
| - TensorVariable, |
1228 |
| - SymNodeVariable, |
1229 |
| - NNModuleVariable, |
1230 |
| - BaseListVariable, |
1231 |
| - UserDefinedVariable, |
1232 |
| - BaseUserFunctionVariable, |
1233 |
| - ConstDictVariable, |
1234 |
| - ), |
1235 |
| - ) |
1236 |
| - ): |
1237 |
| - # <non-None> is None |
1238 |
| - return self.push( |
1239 |
| - ConstantVariable(supported_const_comparison_ops[op](object(), None)) |
1240 |
| - ) |
1241 |
| - self.push( |
1242 |
| - BuiltinVariable(supported_comparison_ops[op]).call_function( |
1243 |
| - self, [left, right], {} |
1244 |
| - ) |
1245 |
| - ) |
| 1208 | + self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) |
1246 | 1209 |
|
1247 | 1210 | def GET_ITER(self, inst):
|
1248 | 1211 | self.call_function(BuiltinVariable(iter), [self.pop()], {})
|
|
0 commit comments