Skip to content

Commit 389b56b

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][guards-cpp-refactor] GetAttrGuardAccessor (pytorch#119833)
Pull Request resolved: pytorch#119833 Approved by: https://github.com/jansel ghstack dependencies: pytorch#119822, pytorch#119827
1 parent 96f45d1 commit 389b56b

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

test/dynamo/test_guard_manager.py

+39
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch._C._dynamo import guards
88

99
RootGuardManager = guards.RootGuardManager
10+
GetAttrGuardAccessor = guards.GetAttrGuardAccessor
1011

1112

1213
def id_type(x):
@@ -137,6 +138,44 @@ def test_guard_manager_leaf_guard(self):
137138
self.assertFalse(guard_manager.check(4))
138139
self.assertFalse(guard_manager.check("foo"))
139140

141+
def test_attr_guard_manager(self):
142+
class Foo:
143+
def __init__(self, x, y):
144+
self.x = x
145+
self.y = y
146+
147+
foo = Foo(1, 2)
148+
guard_manager = RootGuardManager()
149+
guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
150+
guard_manager.getattr_manager("x", 1).add_lambda_guard(
151+
functools.partial(equals_match, expected=foo.x),
152+
equals_match_verbose_code_parts(foo.x),
153+
)
154+
guard_manager.getattr_manager("y", 2).add_lambda_guard(
155+
functools.partial(equals_match, expected=foo.y),
156+
equals_match_verbose_code_parts(foo.y),
157+
)
158+
self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
159+
# 2 child managers, one for x and one for y
160+
self.assertEqual(len(guard_manager.get_accessors()), 2)
161+
self.assertTrue(
162+
isinstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor)
163+
)
164+
self.assertTrue(
165+
isinstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor)
166+
)
167+
# Check leaf guards on child managers
168+
self.assertEqual(
169+
len(guard_manager.getattr_manager("x", None).get_leaf_guards()), 1
170+
)
171+
self.assertEqual(
172+
len(guard_manager.getattr_manager("y", None).get_leaf_guards()), 1
173+
)
174+
175+
self.assertTrue(guard_manager.check(foo))
176+
self.assertFalse(guard_manager.check(Foo(3, 4)))
177+
self.assertFalse(guard_manager.check("foo"))
178+
140179

141180
if __name__ == "__main__":
142181
from torch._dynamo.test_case import run_tests

torch/csrc/dynamo/guards.cpp

+65-1
Original file line numberDiff line numberDiff line change
@@ -1353,6 +1353,60 @@ std::unique_ptr<GuardManager> make_guard_manager(
13531353
return std::make_unique<GuardManager>(root);
13541354
}
13551355

1356+
/**
1357+
* Represents __getattr__ acccessor.
1358+
*/
1359+
class GetAttrGuardAccessor : public GuardAccessor {
1360+
public:
1361+
GetAttrGuardAccessor(
1362+
RootGuardManager* root,
1363+
py::str name,
1364+
py::handle example_value)
1365+
: GuardAccessor(root, name, example_value), _attr_name(name.ptr()) {}
1366+
1367+
// NB: Intentional duplication between check_nopybind and
1368+
// check_verbose_nopybind.
1369+
bool check_nopybind(PyObject* obj) override { // borrowed ref
1370+
PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
1371+
if (x == nullptr) {
1372+
// Attribute absent, clear the exception and return false.
1373+
PyErr_Clear();
1374+
return false;
1375+
}
1376+
bool result = _guard_manager->check_nopybind(x);
1377+
Py_DECREF(x);
1378+
return result;
1379+
}
1380+
1381+
GuardDebugInfo check_verbose_nopybind(
1382+
PyObject* obj) override { // borrowed ref
1383+
PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
1384+
if (x == nullptr) {
1385+
// Attribute absent, clear the exception and return false.
1386+
PyErr_Clear();
1387+
return GuardDebugInfo(
1388+
false,
1389+
std::string("get attr failed for attr name ") +
1390+
py::str(_attr_name).cast<std::string>(),
1391+
0);
1392+
}
1393+
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
1394+
Py_DECREF(x);
1395+
return result;
1396+
}
1397+
1398+
std::string repr() const override {
1399+
// Helpful when priting GuardManager tree structure.
1400+
return "GetAttrGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
1401+
")";
1402+
}
1403+
1404+
private:
1405+
// no need of py::object here because the attr_name is already passed on to
1406+
// the base class as accessor_key which is a py::object.
1407+
PyObject* _attr_name;
1408+
};
1409+
13561410
} // namespace
13571411

13581412
static void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj) {
@@ -1458,6 +1512,10 @@ PyObject* torch_c_dynamo_guards_init() {
14581512
py::class_<GuardAccessor, std::unique_ptr<GuardAccessor>>(
14591513
py_m, "GuardAccessor")
14601514
.def("repr", &GuardAccessor::repr);
1515+
py::class_<
1516+
GetAttrGuardAccessor,
1517+
GuardAccessor,
1518+
std::unique_ptr<GetAttrGuardAccessor>>(py_m, "GetAttrGuardAccessor");
14611519

14621520
// Guard Manager - No constructor in python, python should use
14631521
// RootGuardManager.
@@ -1510,7 +1568,13 @@ PyObject* torch_c_dynamo_guards_init() {
15101568
py::object verbose_code_parts) -> void {
15111569
self.add_leaf_guard(
15121570
std::make_shared<EQUALS_MATCH>(value, verbose_code_parts));
1513-
});
1571+
})
1572+
// return by reference because C++ GuardManager has the ownership of
1573+
// accessors and guard managers
1574+
.def(
1575+
"getattr_manager",
1576+
&GuardManager::get_child_manager<GetAttrGuardAccessor>,
1577+
py::return_value_policy::reference);
15141578

15151579
// Root Guard Manager
15161580
py::class_<RootGuardManager, GuardManager, std::unique_ptr<RootGuardManager>>(

0 commit comments

Comments
 (0)