Skip to content

Commit 312ce35

Browse files
soulitzerpytorchmergebot
authored andcommitted
Rename singleton int to nested int (pytorch#119661)
Pull Request resolved: pytorch#119661 Approved by: https://github.com/ezyang
1 parent b97fa6a commit 312ce35

21 files changed

+99
-99
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,78 @@
1-
#include <ATen/core/SingletonSymNodeImpl.h>
1+
#include <ATen/core/NestedIntSymNodeImpl.h>
22
#include <c10/core/SymNodeImpl.h>
33
#include <c10/util/Exception.h>
44

55
namespace c10 {
66

77
namespace {
88
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
9-
TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value());
10-
c10::optional<int64_t> c = rhs->singleton_int();
9+
TORCH_INTERNAL_ASSERT(lhs->nested_int().has_value());
10+
c10::optional<int64_t> c = rhs->nested_int();
1111
return (
12-
c.has_value() && lhs->singleton_int() == *c &&
13-
lhs->singleton_coeff() == rhs->singleton_coeff());
12+
c.has_value() && lhs->nested_int() == *c &&
13+
lhs->nested_int_coeff() == rhs->nested_int_coeff());
1414
}
1515
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
16-
if (auto mb_si = lhs->singleton_int()) {
17-
if (auto mb_si2 = rhs->singleton_int()) {
16+
if (auto mb_si = lhs->nested_int()) {
17+
if (auto mb_si2 = rhs->nested_int()) {
1818
if (*mb_si == *mb_si2) {
19-
return lhs->singleton_coeff() >= rhs->singleton_coeff();
19+
return lhs->nested_int_coeff() >= rhs->nested_int_coeff();
2020
}
21-
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
21+
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
2222
}
2323
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2424
if (rhs->constant_int() && *rhs->constant_int() <= 2) {
2525
return true;
2626
}
27-
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
28-
} else if (rhs->singleton_int()) {
27+
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
28+
} else if (rhs->nested_int()) {
2929
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
3030
if (lhs->constant_int() && *lhs->constant_int() < 2) {
3131
return false;
3232
}
33-
TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate");
33+
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
3434
}
35-
TORCH_INTERNAL_ASSERT(false, "expect at least one singleton");
35+
TORCH_INTERNAL_ASSERT(false, "expect at least one nested int");
3636
}
3737
} // namespace
3838

39-
c10::SymNode SingletonSymNodeImpl::eq(const c10::SymNode& other) {
39+
c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) {
4040
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
4141
_eq("eq", this, other.get())));
4242
}
4343

44-
c10::SymNode SingletonSymNodeImpl::ne(const c10::SymNode& other) {
44+
c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) {
4545
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
4646
!_eq("ne", this, other.get())));
4747
}
4848

49-
c10::SymNode SingletonSymNodeImpl::ge(const c10::SymNode& other) {
49+
c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) {
5050
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
5151
_ge("ge", this, other.get())));
5252
}
5353

54-
c10::SymNode SingletonSymNodeImpl::gt(const c10::SymNode& other) {
54+
c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) {
5555
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
5656
!_ge("gt", other.get(), this)));
5757
}
5858

59-
c10::SymNode SingletonSymNodeImpl::lt(const c10::SymNode& other) {
59+
c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) {
6060
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
6161
!_ge("lt", this, other.get())));
6262
}
6363

64-
c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) {
64+
c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) {
6565
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
6666
_ge("le", other.get(), this)));
6767
}
6868

69-
c10::SymNode SingletonSymNodeImpl::mul(const c10::SymNode& other) {
70-
if (auto mb_si = other->singleton_int()) {
71-
TORCH_CHECK(false, "Singleton int cannot be multiplied by singleton int");
69+
c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) {
70+
if (auto mb_si = other->nested_int()) {
71+
TORCH_CHECK(false, "nested int cannot be multiplied by nested int");
7272
}
7373
c10::optional<int64_t> c = other->constant_int();
7474
TORCH_CHECK(c.has_value());
75-
return SymNode(c10::make_intrusive<SingletonSymNodeImpl>(val_, coeff_ * *c));
75+
return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_ * *c));
7676
}
7777

7878
} // namespace c10

aten/src/ATen/core/SingletonSymNodeImpl.h aten/src/ATen/core/NestedIntSymNodeImpl.h

+12-12
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ namespace c10 {
1616
// allows us to simply return [B, j0, D] if someone queries for the size of our
1717
// tensor.
1818
//
19-
// Morally we define comparison between two singleton ints to return true if
19+
// Morally we define comparison between two nested ints to return true if
2020
// that comparison holds for all corresponding elements of the arrays they
21-
// represent. Comparison between a singleton int and a plain int is defined
21+
// represent. Comparison between a nested int and a plain int is defined
2222
// similarly.
2323
//
2424
// To simulate this desired behavior but also avoid the O(N) cost of checking,
@@ -32,13 +32,13 @@ namespace c10 {
3232
// differentiate the two cases.
3333
//
3434
// During tracing the strides of the outputs need to be a function of the size
35-
// and strides of the inputs so it is important that SingletonSymNode itself is
35+
// and strides of the inputs so it is important that NestedIntSymNode itself is
3636
// able to express this.
37-
class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
37+
class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
3838
public:
3939
// CAUTION: you should probably not be constructing these directly; please
4040
// the higher-level API in python instead (TODO: actually introduce that).
41-
explicit SingletonSymNodeImpl(int64_t val, int64_t coeff)
41+
explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
4242
: val_(val), coeff_(coeff) {}
4343

4444
bool bool_() override {
@@ -88,9 +88,9 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
8888
return std::to_string(coeff_) + "*j" + std::to_string(val_);
8989
}
9090

91-
// NOTE [ Inequalities with SingletonInt ]
91+
// NOTE [ Inequalities with nested int ]
9292
//
93-
// The semantics of SingletonInt when it comes to relations is that it is
93+
// The semantics of nested int when it comes to relations is that it is
9494
// treated as integer known to be within a certain range,
9595
//
9696
// j0 \in [2, int64_t::max]
@@ -117,7 +117,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
117117
// [ Coefficient are assumed positive ]
118118
//
119119
// For the purpose of computing inequalities, we consider the coefficient of
120-
// the SingletonInt to be a positive integer.
120+
// the nested int to be a positive integer.
121121
//
122122
// Thus, no modifications are needed to the logic since
123123
// j0 >= k implies coeff * j0 >= k
@@ -130,11 +130,11 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
130130
c10::SymNode le(const c10::SymNode& other) override;
131131
c10::SymNode mul(const c10::SymNode& other) override;
132132

133-
c10::optional<int64_t> singleton_int() override {
133+
c10::optional<int64_t> nested_int() override {
134134
return val_;
135135
}
136136

137-
c10::optional<int64_t> singleton_coeff() override {
137+
c10::optional<int64_t> nested_int_coeff() override {
138138
return coeff_;
139139
}
140140

@@ -144,7 +144,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
144144

145145
#define DEFINE_BINARY_NOT_SUPPORTED(name) \
146146
c10::SymNode name(const c10::SymNode& other) override { \
147-
TORCH_CHECK(false, #name " not supported by SingletonSymNode"); \
147+
TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
148148
}
149149

150150
DEFINE_BINARY_NOT_SUPPORTED(add)
@@ -162,7 +162,7 @@ class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
162162

163163
#define DEFINE_NOT_SUPPORTED(name) \
164164
c10::SymNode name() override { \
165-
TORCH_CHECK(false, #name " is not supported by SingletonSymNode"); \
165+
TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
166166
}
167167

168168
DEFINE_NOT_SUPPORTED(sym_not)

build_variables.bzl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ aten_cpu_source_non_codegen_list = [
10371037
"aten/src/ATen/core/operator_name.cpp",
10381038
"aten/src/ATen/core/TorchDispatchUtils.cpp",
10391039
"aten/src/ATen/core/register_symbols.cpp",
1040-
"aten/src/ATen/core/SingletonSymNodeImpl.cpp",
1040+
"aten/src/ATen/core/NestedIntSymNodeImpl.cpp",
10411041
"aten/src/ATen/core/class_type.cpp",
10421042
"aten/src/ATen/core/type.cpp",
10431043
"aten/src/ATen/core/type_factory.cpp",

c10/core/ConstantSymNodeImpl.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ namespace c10 {
44

55
// This is used to support the case where the lhs is a constant symnode
66
// and the rhs is a singleton symnode. This situation occurs today when we
7-
// perform a binary op between singleton int and plain int and the
7+
// perform a binary op between nested int and plain int and the
88
// singleton promotes the int into a constant symnode. If we'd like to
99
// support more combinations in the future, we may need to implement some
1010
// kind of multiple dispatch.
1111
#define DEFINE_BINARY_OP(OP, ROP) \
1212
template <typename T> \
1313
c10::SymNode ConstantSymNodeImpl<T>::OP(const c10::SymNode& other) { \
14-
TORCH_INTERNAL_ASSERT(other->singleton_int().has_value()); \
14+
TORCH_INTERNAL_ASSERT(other->nested_int().has_value()); \
1515
return other->ROP( \
1616
c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim_copy(this)); \
1717
}

c10/core/SymNodeImpl.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,10 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
185185
virtual std::string str() {
186186
TORCH_CHECK(false, "NYI");
187187
};
188-
virtual c10::optional<int64_t> singleton_int() {
188+
virtual c10::optional<int64_t> nested_int() {
189189
return c10::nullopt;
190190
}
191-
virtual c10::optional<int64_t> singleton_coeff() {
191+
virtual c10::optional<int64_t> nested_int_coeff() {
192192
return c10::nullopt;
193193
}
194194
virtual c10::optional<int64_t> constant_int() {

docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@
944944
"is_channels_last_strides_3d",
945945
"is_contiguous",
946946
"is_non_overlapping_and_dense_indicator",
947-
"is_singleton",
947+
"is_nested_int",
948948
"is_symbol_binding_fx_node",
949949
"is_symbolic",
950950
# torch.fx.experimental.unification.core

test/cpp/api/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ set(TORCH_API_TEST_SOURCES
4141
${TORCH_API_TEST_DIR}/inference_mode.cpp
4242
${TORCH_API_TEST_DIR}/grad_mode.cpp
4343
${TORCH_API_TEST_DIR}/operations.cpp
44-
${TORCH_API_TEST_DIR}/singleton_int.cpp
44+
${TORCH_API_TEST_DIR}/nested_int.cpp
4545
)
4646
if(USE_CUDA OR USE_ROCM)
4747
list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp)

test/cpp/api/singleton_int.cpp test/cpp/api/nested_int.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
#include <gtest/gtest.h>
22

3-
#include <ATen/core/SingletonSymNodeImpl.h>
3+
#include <ATen/core/NestedIntSymNodeImpl.h>
44
#include <c10/core/SymInt.h>
55
#include <c10/core/SymNodeImpl.h>
66
#include <torch/torch.h>
77

88
#include <test/cpp/api/support.h>
99

10-
TEST(SingletonIntTest, Comparisons) {
10+
TEST(NestedIntTest, Comparisons) {
1111
auto a = c10::SymInt(
12-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
12+
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
1313
auto b = c10::SymInt(
14-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
14+
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
1515
auto c = c10::SymInt(
16-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
16+
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(2, 1)));
1717
auto d = c10::SymInt(3);
1818

1919
ASSERT_TRUE(a == a);
@@ -85,11 +85,11 @@ TEST(SingletonIntTest, Comparisons) {
8585
ASSERT_TRUE(a > 1);
8686
}
8787

88-
TEST(SingletonIntTest, WiithFactor) {
88+
TEST(NestedIntTest, WithFactor) {
8989
auto a = c10::SymInt(
90-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
90+
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 5)));
9191
auto b = c10::SymInt(
92-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
92+
c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 10)));
9393
// eq
9494
ASSERT_FALSE(a == b);
9595
ASSERT_FALSE(a >= b);

test/test_dynamic_shapes.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -849,10 +849,10 @@ def test_symnode_hashing(self):
849849
with self.assertRaisesRegex(TypeError, "unhashable"):
850850
hash(x)
851851

852-
# Singleton SymInt, constant SymBool, SymNode are hashable
853-
j1 = torch._C._get_singleton_int(1, 1)
854-
j1_copy = torch._C._get_singleton_int(1, 1)
855-
j2 = torch._C._get_singleton_int(2, 1)
852+
# NestedInt (SymInt), constant SymBool, SymNode are hashable
853+
j1 = torch._C._get_nested_int(1, 1)
854+
j1_copy = torch._C._get_nested_int(1, 1)
855+
j2 = torch._C._get_nested_int(2, 1)
856856
t = self.get_constant_bool(True)
857857
t_copy = self.get_constant_bool(True)
858858
f = self.get_constant_bool(False)
@@ -872,14 +872,14 @@ def test_symnode_hashing(self):
872872
hash(m)
873873

874874
def test_non_symbolic_symnode(self):
875-
j1 = torch._C._get_singleton_int(1, 1)
876-
j2 = torch._C._get_singleton_int(1, 1)
877-
j3 = torch._C._get_singleton_int(3, 1)
875+
j1 = torch._C._get_nested_int(1, 1)
876+
j2 = torch._C._get_nested_int(1, 1)
877+
j3 = torch._C._get_nested_int(3, 1)
878878

879879
self.assertIsInstance(j1, torch.SymInt)
880880
self.assertNotIsInstance(j1, int)
881881

882-
with self.assertRaisesRegex(RuntimeError, "add not supported by SingletonSymNode"):
882+
with self.assertRaisesRegex(RuntimeError, "add not supported by NestedIntSymNode"):
883883
j1 + 3
884884

885885
self.assertFalse(j1 == 3)

test/test_nestedtensor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3028,9 +3028,9 @@ def test_tensor_attributes(self, device):
30283028
"directly calling torch.ops.aten.size"):
30293029
torch.ops.aten.size.default(nt)
30303030

3031-
singleton_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1)
3032-
self.assertEqual(nt.size(), (3, singleton_int, 3))
3033-
self.assertEqual(nt.shape, (3, singleton_int, 3))
3031+
nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1)
3032+
self.assertEqual(nt.size(), (3, nested_int, 3))
3033+
self.assertEqual(nt.shape, (3, nested_int, 3))
30343034
self.assertEqual(nt.dim(), 3)
30353035
self.assertEqual(nt.numel(), 27)
30363036

torch/_C/__init__.pyi.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -1533,7 +1533,7 @@ def _are_functorch_transforms_active() -> _bool: ...
15331533
# Define in torch/csrc/autograd/init.cpp
15341534
def _set_python_dispatcher(dispatcher: object) -> None: ...
15351535

1536-
def _get_singleton_int(id: _int, coeff: _int) -> SymInt: ...
1536+
def _get_nested_int(id: _int, coeff: _int) -> SymInt: ...
15371537

15381538
def _get_constant_bool_symnode(val: _bool) -> Any: ...
15391539

torch/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,12 @@ def __repr__(self):
301301
return str(self.node)
302302

303303
def __hash__(self) -> builtins.int:
304-
ret = self.node.singleton_int()
304+
ret = self.node.nested_int()
305305
if ret is not None:
306306
return hash(ret)
307307
else:
308308
# We could support constant SymInts as well, but not doing it for now
309-
raise TypeError("unhashable type: non-singleton SymInt")
309+
raise TypeError("unhashable type: non-nested SymInt")
310310

311311
class SymFloat:
312312
"""

torch/_dynamo/trace_rules.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@
491491
"torch._C._get_privateuse1_backend_name",
492492
"torch._C._get_qengine",
493493
"torch._C._get_schema",
494-
"torch._C._get_singleton_int",
494+
"torch._C._get_nested_int",
495495
"torch._C._get_tensor_metadata",
496496
"torch._C._get_tracing_state",
497497
"torch._C._get_upgrader_ranges",

torch/_dynamo/variables/builder.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1598,9 +1598,9 @@ def _automatic_dynamic(
15981598

15991599
# We preserve the dynamism of inputs. For example, when users call
16001600
# make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
1601-
from torch.fx.experimental.symbolic_shapes import is_singleton
1601+
from torch.fx.experimental.symbolic_shapes import is_nested_int
16021602

1603-
if any(isinstance(s, SymInt) and not is_singleton(s) for s in e.size()):
1603+
if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()):
16041604
return StatefulSymbolicContext(
16051605
dynamic_sizes=[
16061606
DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
@@ -1729,7 +1729,7 @@ def update_dim2constraint(dim, constraint_range, debug_name):
17291729
constraint_dim is not None
17301730
or marked_dynamic
17311731
or marked_weak_dynamic
1732-
or is_singleton(e.shape[i])
1732+
or is_nested_int(e.shape[i])
17331733
):
17341734
# NB: We could assert static_shapes is False here, but it
17351735
# seems better to allow the user to override symbolic_context in this

0 commit comments

Comments
 (0)