
Commit 44257c0

pytorchmergebot authored and committed
[dynamo] Fix constant propagation in builtins and UserClasses (pytorch#131354)
* Fixes pytorch#118675
* Replaces pytorch#118994

Pull Request resolved: pytorch#131354
Approved by: https://github.com/jansel, https://github.com/anijain2305
1 parent a951d99 commit 44257c0

12 files changed: +91 −47 lines
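
For context, a minimal repro sketch of the pattern this commit fixes, distilled from the test_const_getattr test added in test/dynamo/test_misc.py below. The class name Widget and the tensor sizes are illustrative, not from the PR; before this change the class-attribute lookups fell back to GetAttrVariable nodes, so the dict keys could not be constant-folded (see pytorch#118675).

import torch


class Widget:  # a plain user-defined class at module scope
    pass


def fn(table):
    # __module__, __name__ and __qualname__ are plain strings on the class
    # object, so Dynamo can now constant-propagate these keys.
    y = table[f"{Widget.__module__}.{Widget.__name__}"]
    z = table[f"{Widget.__class__.__module__}.{Widget.__class__.__qualname__}"]
    return y + z


table = {
    f"{Widget.__module__}.Widget": torch.randn(8),
    "builtins.type": torch.randn(8),
}

compiled = torch.compile(fn)
print(torch.allclose(fn(table), compiled(table)))
# The new test additionally asserts frame_count == 1, i.e. a single
# compiled frame with no graph breaks.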

test/dynamo/test_higher_order_ops.py

+3 −3

@@ -3818,10 +3818,10 @@ def wrapper_fn(model, params, buffers, inputs):
         if torch._dynamo.config.inline_inbuilt_nn_modules:
             expected = """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"):
+    def forward(self, L_params_l1_weight_: "f32[1, 1]", L_buffers_buffer_: "f32[1]", L_params_l1_bias_: "f32[1]", L_inputs_: "f32[1, 1]"):
         l_params_l1_weight_ = L_params_l1_weight_
-        l_params_l1_bias_ = L_params_l1_bias_
         l_buffers_buffer_ = L_buffers_buffer_
+        l_params_l1_bias_ = L_params_l1_bias_
         l_inputs_ = L_inputs_

         linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None

@@ -6005,7 +6005,7 @@ def wrapper_fn(x, y):
             return torch.func.vmap(f)(x, y)

         actual = wrapper_fn(x, y)
-        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
+        expected = torch.compile(wrapper_fn, backend="aot_eager")(x, y)
         self.assertEqual(len(counters["graph_break"]), 0)
         self.assertEqual(actual, expected)
         self.assertEqual(some_list, [1, 1])

test/dynamo/test_misc.py

+28 −0

@@ -101,6 +101,12 @@
 TPFLAGS_MAPPING = 1 << 6


+# A class defined in the global scope, used in MiscTests.test_const_getattr
+class _B:
+    def __init__(self):
+        pass
+
+
 # Specializes a test to run only if translation validation is set.
 def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
     @functools.wraps(fn)

@@ -1412,6 +1418,28 @@ def fn(x, s):
         # One recompile per differing input type
         self.assertEqual(cnts.frame_count, 3)

+    def test_const_getattr(self):
+        # See https://github.com/pytorch/pytorch/issues/118675
+        def fn(x):
+            y = x[f"{_B.__module__}.{_B.__name__}"]
+            z = x[f"{_B.__class__.__module__}.{_B.__name__}"]
+            u = x[f"{_B.__class__.__module__}.{_B.__class__.__qualname__}"]
+            return y + z + u
+
+        args = (
+            {
+                f"{_B.__module__}._B": torch.randn(10),
+                "builtins._B": torch.randn(10),
+                "builtins.type": torch.randn(10),
+            },
+        )
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(cnts)(fn)
+
+        self.assertEqual(fn(*args), opt_fn(*args))
+        self.assertEqual(cnts.frame_count, 1)
+
     def test_cell_output1(self):
         out = None

test/dynamo_expected_failures/TestArrayCreationCopyArgument.test_striding_not_ok

Whitespace-only changes.

torch/_dynamo/guards.py

+7 −7

@@ -315,7 +315,7 @@ def uninteresting_files():
 _CLOSURE_VARS: Optional[Dict[str, object]] = None


-def _get_closure_vars():
+def get_closure_vars():
     global _CLOSURE_VARS
     if _CLOSURE_VARS is None:
         _CLOSURE_VARS = {

@@ -1150,7 +1150,7 @@ def add_python_lambda_leaf_guard_to_root(
         is_epilogue=True,
     ):
         if closure_vars is None:
-            closure_vars = _get_closure_vars()
+            closure_vars = get_closure_vars()
         # Adds a lambda leaf guard to the root guard manager. It wraps the
         # code_parts in a function object which is then passed on to the leaf
         # guard.

@@ -1177,7 +1177,7 @@ def add_python_lambda_leaf_guard_to_root(
     # (like its type) which is what you permanently install into the
     # guard code.
     def get(self, name: str) -> Any:
-        return eval(name, self.scope, _get_closure_vars())
+        return eval(name, self.scope, get_closure_vars())

     # Registers the usage of the source name referenced by the
     # string (or stored in the Guard) as being guarded upon. It's important

@@ -1497,7 +1497,7 @@ def EQUALS_MATCH(self, guard: Guard):
             self._set_guard_export_info(guard, code)

             self.get_guard_manager(guard).add_lambda_guard(
-                _get_closure_vars()["__math_isnan"],
+                get_closure_vars()["__math_isnan"],
                 get_verbose_code_parts(code, guard),
             )
             return

@@ -1510,7 +1510,7 @@ def EQUALS_MATCH(self, guard: Guard):
             self._set_guard_export_info(guard, code)

             self.get_guard_manager(guard).add_lambda_guard(
-                _get_closure_vars()["__numpy_isnan"],
+                get_closure_vars()["__numpy_isnan"],
                 get_verbose_code_parts(code, guard),
             )
             return

@@ -1786,7 +1786,7 @@ def get_sources(t_id, dim):
         self.add_python_lambda_leaf_guard_to_root(
             code_parts,
             verbose_code_parts,
-            closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
+            closure_vars={**SYMPY_INTERP, **get_closure_vars()},
         )

     def TENSOR_MATCH(self, guard: Guard, value=None):

@@ -2375,7 +2375,7 @@ def add_code_part(code_part, guard, log_only=False):
             "___check_global_state": global_state.check,
             "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
             **SYMPY_INTERP,
-            **_get_closure_vars(),
+            **get_closure_vars(),
         }

         globals_for_guard_fn = {"G": builder.scope["G"]}

torch/_dynamo/variables/base.py

+16 −8

@@ -6,9 +6,9 @@

 from .. import variables
 from ..current_scope_id import current_scope_id
-from ..exc import unimplemented
+from ..exc import unimplemented, Unsupported
 from ..source import AttrSource, Source
-from ..utils import istype
+from ..utils import is_function_or_wrapper, istype


 if TYPE_CHECKING:

@@ -238,16 +238,24 @@ def make_guard(self, fn):
         raise NotImplementedError

     def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
-        """getattr(self, name) returning a python constant"""
-        raise NotImplementedError
+        v = self.as_python_constant()
+        try:
+            return getattr(v, name)
+        except AttributeError:
+            raise NotImplementedError from None

     def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
         """getattr(self, name) returning a new variable"""
-        value = self.const_getattr(tx, name)
-        if not variables.ConstantVariable.is_literal(value):
-            raise NotImplementedError
+        from .misc import GetAttrVariable
+
         source = self.source and AttrSource(self.source, name)
-        return variables.ConstantVariable.create(value, source=source)
+        try:
+            value = self.const_getattr(tx, name)
+            if not is_function_or_wrapper(value):
+                return VariableTracker.build(tx, value, source)
+        except (NotImplementedError, Unsupported):
+            pass
+        return GetAttrVariable(self, name, source=source)

     def is_proxy(self):
         try:
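
To make the new control flow concrete, here is a toy, standalone model of the fallback order the rewritten const_getattr / var_getattr implement. The names below (const_getattr, var_getattr, _is_function_or_wrapper, the tuple return values) are illustrative stand-ins, not the real Dynamo types; the real methods operate on VariableTracker objects and also catch Unsupported.

import types


def _is_function_or_wrapper(value):
    # rough stand-in for torch._dynamo.utils.is_function_or_wrapper
    return isinstance(
        value,
        (types.FunctionType, types.BuiltinFunctionType, types.MethodDescriptorType),
    )


def const_getattr(constant_value, name):
    # new default: read the attribute off the python constant, translating
    # AttributeError into NotImplementedError
    try:
        return getattr(constant_value, name)
    except AttributeError:
        raise NotImplementedError from None


def var_getattr(constant_value, name):
    # try constant propagation first; fall back to a symbolic getattr node
    try:
        value = const_getattr(constant_value, name)
        if not _is_function_or_wrapper(value):
            return ("constant", value)      # stands in for VariableTracker.build
    except NotImplementedError:
        pass
    return ("getattr_node", name)           # stands in for GetAttrVariable


class _B:
    pass


print(var_getattr(_B, "__module__"))   # ('constant', '__main__') -- now folded
print(var_getattr(_B, "missing"))      # ('getattr_node', 'missing') -- graceful fallback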

torch/_dynamo/variables/builtin.py

+1 −5

@@ -1714,14 +1714,10 @@ def call_getattr(
             if config.replay_record_enabled:
                 tx.exec_recorder.record_module_access(obj.value, name, member)
             return VariableTracker.build(tx, member, source)
-
         elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
             return ConstantVariable.create(getattr(obj.fn, name))
         else:
-            try:
-                return obj.var_getattr(tx, name)
-            except NotImplementedError:
-                return GetAttrVariable(obj, name, source=source)
+            return obj.var_getattr(tx, name)

     def call_setattr(
         self,

torch/_dynamo/variables/distributed.py

+4 −1

@@ -1,7 +1,7 @@
 # mypy: ignore-errors
 import functools
 import inspect
-from typing import Dict, List, TYPE_CHECKING
+from typing import Any, Dict, List, TYPE_CHECKING


 import torch
 from torch.fx.experimental._backward_state import BackwardState

@@ -214,6 +214,9 @@ def is_device_mesh(value):
     def as_python_constant(self):
         return self.value

+    def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
+        raise NotImplementedError
+
     def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
         if name == "ndim":
             return ConstantVariable.create(self.value.ndim)

torch/_dynamo/variables/functions.py

+2 −1

@@ -164,7 +164,8 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None:
             self.is_constant = False

         assert isinstance(
-            fn, (types.FunctionType, torch.jit.ScriptFunction)
+            fn,
+            (types.BuiltinFunctionType, types.FunctionType, torch.jit.ScriptFunction),
         ), f"expected FunctionType found {typestr(fn)} {fn}"
         # TODO(anijain2305) - Replace directly calling UserFunctionVariable with
         # VariableBuilder, which handles the wrapping of _torchdynamo_inline.
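
A quick illustration of why the assert needed widening (standard CPython behavior, not code from the PR): C-level builtins such as len are BuiltinFunctionType rather than FunctionType, so they previously tripped the assertion when wrapped as a UserFunctionVariable.

import types


def py_fn():
    pass


print(isinstance(py_fn, types.FunctionType))        # True  - Python-level function
print(isinstance(len, types.BuiltinFunctionType))   # True  - C-level builtin
print(isinstance(len, types.FunctionType))          # False - failed the old assert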

torch/_dynamo/variables/lists.py

+4 −1

@@ -82,7 +82,10 @@ def debug_repr_helper(self, prefix, suffix):
         return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix

     def as_python_constant(self):
-        return self.python_type()([x.as_python_constant() for x in self.items])
+        try:
+            return self.python_type()([x.as_python_constant() for x in self.items])
+        except RecursionError:
+            unimplemented(f"recursive containment {self}")

     def as_proxy(self):
         assert self.python_type() is not SizeVariable
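
A standalone illustration of the failure mode the new try/except covers: a self-referential container makes a naive recursive conversion blow past Python's recursion limit. as_constant below is a hypothetical stand-in for the real as_python_constant chain, not Dynamo code.

def as_constant(x):
    # naive recursive "to python constant" conversion
    if isinstance(x, list):
        return [as_constant(item) for item in x]
    return x


cyclic = [1, 2]
cyclic.append(cyclic)          # the list now contains itself

try:
    as_constant(cyclic)
except RecursionError:
    print("RecursionError: now surfaced via unimplemented() instead of crashing the trace")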

torch/_dynamo/variables/misc.py

+4 −1

@@ -9,7 +9,7 @@
 import sys
 import types
 import warnings
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING


 import torch._C
 import torch._numpy as tnp

@@ -1617,6 +1617,9 @@ def python_type(self):
     def as_python_constant(self):
         return self.random

+    def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
+        raise NotImplementedError
+
     @staticmethod
     def is_supported_random_obj(val):
         if type(val) is not random.Random:

torch/_dynamo/variables/tensor.py

+18 −16

@@ -229,7 +229,6 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
         # (1) the tensor is a traceable tensor subclass
         # (2) We are getattr'ing an inner tensor from that subclass
         if not self.source and is_traceable_wrapper_subclass(fake_val):
-            fake_val = self.proxy.node.meta["example_value"]
             attrs, ctx = fake_val.__tensor_flatten__()
             proxy = getattr(self.as_proxy(), name)
             example_value = getattr(fake_val, name)

@@ -245,14 +244,19 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
             return VariableTracker.build(tx, example_value)

         if not (self.source and self.source.subguards_allowed()):
-            raise NotImplementedError
+            return
+
+        from ..guards import get_closure_vars, GuardBuilder

         # For local source, we associate the real value. We use this real value
         # for implementing getattr fallthrough on the variable tracker base class.
-
         # Note - this scope construction is mirrored in guards
         # A subsequent PR will introduce a util.
-        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
+        scope = {
+            "L": tx.output.local_scope,
+            "G": tx.output.global_scope,
+            **get_closure_vars(),
+        }
         try:
             # We raise in case we get a typerror bug w/ SuperSource.
             # SuperSource has bugs in it atm, and can produce code like

@@ -261,23 +265,24 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
             # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
             _input_associated_real_value = eval(self.source.name(), scope)
         except Exception as exc:
-            raise NotImplementedError from exc
+            msg = f"{exc!r} raised in eval('{self.source.name()}')"
+            raise NotImplementedError(msg) from exc

+        real_value = getattr(_input_associated_real_value, name)
         if _input_associated_real_value is None:
-            raise NotImplementedError
+            return

         if object_has_getattribute(_input_associated_real_value):
-            raise NotImplementedError
+            return

         if get_custom_getattr(_input_associated_real_value):
-            raise NotImplementedError
+            return

-        real_value = getattr(_input_associated_real_value, name)
         if callable(real_value):
             # Callables have more nuanced handling, and we should let the existing system delegate here.
             # Raising was past behavior and so should always be sound to fall back.
             # Note - at a certain point we may want to handle
-            raise NotImplementedError
+            return

         attr_source = AttrSource(self.source, name)
         install_guard(attr_source.make_guard(GuardBuilder.HASATTR))

@@ -1215,8 +1220,6 @@ def var_getattr(self, tx: "InstructionTranslator", name):
         from ..utils import numpy_attr_wrapper
         from .builder import wrap_fx_proxy

-        result = None
-
         example_value = self.as_proxy().node.meta["example_value"]
         example_ndarray = tnp.ndarray(example_value)

@@ -1235,7 +1238,7 @@ def insert_into_graph():
                 (self.as_proxy(), name),
                 {},
             )
-            result = NumpyNdarrayVariable.create(tx, proxy)
+            return NumpyNdarrayVariable.create(tx, proxy)

         # These are awkward to implement. The standard playbook for torch._numpy
         # interop is to trace a call into the torch._numpy wrapper which works for

@@ -1264,9 +1267,8 @@ def insert_into_graph():
             unimplemented(f"TODO: add support for ndarray.{name}")
         elif name in ["__version__"]:
             unimplemented("delegate np.__version__ to NumPy")
-        if result is None:
-            raise NotImplementedError
-        return result
+        else:
+            return super().var_getattr(tx, name)

     @staticmethod
     def patch_args(name, args, kwargs):

torch/_dynamo/variables/user_defined.py

+4 −4

@@ -181,10 +181,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
         ):
             return super().var_getattr(tx, name)

-        try:
-            obj = inspect.getattr_static(self.value, name)
-        except AttributeError:
-            obj = None
+        obj = inspect.getattr_static(self.value, name, None)

         if isinstance(obj, staticmethod):
             return VariableTracker.build(tx, obj.__get__(self.value), source)

@@ -206,6 +203,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
         ):
             return VariableTracker.build(tx, obj.__get__(self.value), source)

+        if inspect.ismemberdescriptor(obj) or inspect.isdatadescriptor(obj):
+            value = getattr(self.value, name)
+            return VariableTracker.build(tx, value, source)
         if ConstantVariable.is_literal(obj):
             return ConstantVariable.create(obj)
         elif isinstance(obj, enum.Enum):
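
This hunk leans on two stdlib behaviors worth spelling out in a pure-stdlib sketch (the Point class is illustrative, not from the PR): inspect.getattr_static accepts a default, which replaces the try/except AttributeError dance, and __slots__ entries / properties are recognized by ismemberdescriptor / isdatadescriptor, in which case the added branch reads the value with a plain getattr.

import inspect


class Point:
    __slots__ = ("x",)          # creates a member descriptor Point.x

    @property
    def norm(self):             # property objects are data descriptors
        return abs(self.x)


# getattr_static with a default replaces try/except AttributeError
print(inspect.getattr_static(Point, "nonexistent", None))                  # None
print(inspect.ismemberdescriptor(inspect.getattr_static(Point, "x")))      # True
print(inspect.isdatadescriptor(inspect.getattr_static(Point, "norm")))     # True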
