
Commit b0a597f

pytorchmergebot authored and committed
Fix pytorch#121334: graph break on constant method call (pytorch#130158)
Pull Request resolved: pytorch#130158
Approved by: https://github.com/lezcano
1 parent 4865c64 commit b0a597f

3 files changed: +79, -32 lines
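Before this fix, calling a method such as is_integer() on a plain Python float argument inside a torch.compile(fullgraph=True) region caused a graph break; with the change the call is constant-folded. A minimal repro in the spirit of the new test_is_integer test (an illustration, not part of the commit itself):

import torch

@torch.compile(backend="eager", fullgraph=True)
def forward(t, m):
    # `m` is a plain Python float; its method call is now constant-folded by
    # Dynamo instead of raising a graph-break error under fullgraph=True.
    return 2 * t if m.is_integer() else t

t = torch.tensor([1])
print(forward(t, 1.0))  # tensor([2])
print(forward(t, 1.5))  # tensor([1])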

test/dynamo/test_functions.py (+34)

@@ -330,6 +330,40 @@ def test_methodcall2(a, b):
     def test_methodcall3(a, b):
         return constant3(a, b=1.0) + b
 
+    def test_is_integer(self):
+        @torch.compile(backend="eager", fullgraph=True)
+        def forward(t, m):
+            return 2 * t if m.is_integer() else t
+
+        t = torch.tensor([1])
+        self.assertEqual(forward(t, 1.0).item(), 2)
+        self.assertEqual(forward(t, 1.5).item(), 1)
+
+    @parametrize(
+        "method, num_type",
+        (
+            ("as_integer_ratio", int),
+            ("bit_length", int),
+            ("conjugate", int),
+            ("as_integer_ratio", float),
+            ("conjugate", float),
+            ("hex", float),
+            ("is_integer", float),
+        ),
+    )
+    def test_number_method(self, method, num_type):
+        def forward(t, m):
+            return 2 * t if getattr(m, method)() else t
+
+        wrapped = torch.compile(backend="eager", fullgraph=True)(forward)
+
+        for i in (0, 1, 2.5):
+            m = num_type(i)
+            t = torch.tensor([1])
+            actual = wrapped(t, m)
+            expected = forward(t, m)
+            self.assertEqual(actual, expected)
+
     @make_test
     def test_device_constant(a):
         return a + torch.ones(1, device=torch.device("cpu"))
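For intuition, the eager reference that test_number_method compares the compiled wrapped function against simply evaluates the number method on the plain Python value and branches on its truthiness. The values involved follow from ordinary Python semantics, for example:

# Plain-Python values that drive the branch in test_number_method:
print(float(1.0).is_integer())   # True  -> compiled and eager both return 2 * t
print(float(2.5).is_integer())   # False -> both return t
print(int(2.5).bit_length())     # 2     -> truthy, so 2 * t
print(float(2.5).hex())          # '0x1.4000000000000p+1' -> non-empty, truthy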

torch/__init__.py (+23, -8)

@@ -511,11 +511,23 @@ def _sympy_(self):
         return self.node.expr
 
     def __hash__(self) -> builtins.int:
+        return hash(self._get_int())
+
+    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
+        """Represent this int as an exact integer ratio"""
+        return self._get_int(), 1
+
+    def bit_length(self) -> "SymInt":
+        return SymInt(self.node.wrap_int(self._get_int().bit_length()))
+
+    def conjugate(self) -> "SymInt":
+        return self
+
+    def _get_int(self) -> builtins.int:
         if self.node.is_nested_int():
-            return hash(self.node.nested_int())
+            return self.node.nested_int()
         else:
-            # Force specialization
-            return hash(builtins.int(self))
+            return builtins.int(self)
 
 
 class SymFloat:

@@ -615,18 +627,21 @@ def is_integer(self):
         """Return True if the float is an integer."""
         raise TypeError("type stub not overridden")
 
+    def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]:
+        """Represent this float as an exact integer ratio"""
+        return self._get_float().as_integer_ratio()
+
     def __repr__(self):
         return self.node.str()
 
     def _sympy_(self):
         return self.node.expr
 
     def __hash__(self):
-        if self.node.is_constant():
-            return hash(self.node.float_())
-        else:
-            # Force specialization
-            return hash(builtins.float(self))
+        return hash(self._get_float())
+
+    def _get_float(self) -> builtins.float:
+        return self.node.float_() if self.node.is_constant() else builtins.float(self)
 
 
 class SymBool:
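The new SymInt/SymFloat methods are written to agree with the builtin int/float methods they shadow, evaluated on the specialized value returned by _get_int()/_get_float(). For reference, the builtin behavior being mirrored:

# Builtin int/float behavior the new SymInt/SymFloat methods mirror:
print((5).as_integer_ratio())    # (5, 1), matching SymInt's (self._get_int(), 1)
print((5).bit_length())          # 3
print((5).conjugate())           # 5, matching SymInt.conjugate returning self
print((2.5).as_integer_ratio())  # (5, 2), via float.as_integer_ratio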

torch/_dynamo/variables/constant.py (+22, -24)

@@ -155,33 +155,31 @@ def call_method(
         except NotImplementedError:
             return super().call_method(tx, name, args, kwargs)
 
-        def has_arith_binop(num_ty):
-            return (
-                isinstance(self.value, num_ty)
-                and hasattr(operator, name)
-                and len(args) == 1
-                and args[0].is_python_constant()
-            )
-
         if isinstance(self.value, str) and name in str.__dict__.keys():
             method = getattr(self.value, name)
             return ConstantVariable.create(method(*const_args, **const_kwargs))
-        elif has_arith_binop(int) or has_arith_binop(float):
-            op = getattr(operator, name)
-            add_target = const_args[0]
-            if isinstance(add_target, (torch.SymInt, torch.SymFloat)):
-                from .tensor import SymNodeVariable
-
-                # Addition between a non sym and sym makes a sym
-                # sym_num = tx.output.register_attr_or_module(
-                #     add_target, f"sym_shape_{add_target}", source=None
-                # )
-                proxy = tx.output.create_proxy(
-                    "call_function", op, (self.value, add_target), {}
-                )
-                return SymNodeVariable.create(tx, proxy, add_target)
-            return ConstantVariable.create(op(self.value, add_target))
-        elif name == "__len__" and not (args or kwargs):
+        elif isinstance(self.value, (float, int)):
+            if not (args or kwargs):
+                return ConstantVariable.create(getattr(self.value, name)())
+            if (
+                hasattr(operator, name)
+                and len(args) == 1
+                and args[0].is_python_constant()
+            ):
+                add_target = const_args[0]
+                op = getattr(operator, name)
+                if isinstance(
+                    add_target, (torch.SymBool, torch.SymFloat, torch.SymInt)
+                ):
+                    # Addition between a non sym and sym makes a sym
+                    proxy = tx.output.create_proxy(
+                        "call_function", op, (self.value, add_target), {}
+                    )
+                    return SymNodeVariable.create(tx, proxy, add_target)
+                else:
+                    return ConstantVariable.create(op(self.value, add_target))
+
+        if name == "__len__" and not (args or kwargs):
             return ConstantVariable.create(len(self.value))
         elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
             assert not kwargs
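Outside of Dynamo, the new dispatch order for int/float constants can be approximated with the sketch below. This is a simplified illustration only: the real call_method runs inside the symbolic interpreter and wraps results in ConstantVariable or SymNodeVariable, and the helper name call_constant_method is hypothetical.

import operator

def call_constant_method(value, name, args=()):
    # Simplified mirror of ConstantVariable.call_method for numeric constants:
    # 1) zero-argument methods (is_integer, hex, bit_length, conjugate, ...)
    #    are evaluated eagerly on the Python value;
    # 2) binary arithmetic dunders with one constant operand fall through to
    #    the matching function in the operator module.
    if isinstance(value, (float, int)):
        if not args:
            return getattr(value, name)()
        if hasattr(operator, name) and len(args) == 1:
            return getattr(operator, name)(value, args[0])
    raise NotImplementedError(name)

print(call_constant_method(1.5, "is_integer"))     # False
print(call_constant_method(7, "bit_length"))       # 3
print(call_constant_method(2.0, "__add__", (3,)))  # 5.0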
