|
1 |
| -#include <ATen/core/SingletonSymNodeImpl.h> |
| 1 | +#include <ATen/core/NestedIntSymNodeImpl.h> |
2 | 2 | #include <c10/core/SymNodeImpl.h>
|
3 | 3 | #include <c10/util/Exception.h>
|
4 | 4 |
|
5 | 5 | namespace c10 {
|
6 | 6 |
|
7 | 7 | namespace {
|
8 | 8 | bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
9 |
| - TORCH_INTERNAL_ASSERT(lhs->singleton_int().has_value()); |
10 |
| - c10::optional<int64_t> c = rhs->singleton_int(); |
| 9 | + TORCH_INTERNAL_ASSERT(lhs->nested_int().has_value()); |
| 10 | + c10::optional<int64_t> c = rhs->nested_int(); |
11 | 11 | return (
|
12 |
| - c.has_value() && lhs->singleton_int() == *c && |
13 |
| - lhs->singleton_coeff() == rhs->singleton_coeff()); |
| 12 | + c.has_value() && lhs->nested_int() == *c && |
| 13 | + lhs->nested_int_coeff() == rhs->nested_int_coeff()); |
14 | 14 | }
|
15 | 15 | bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
16 |
| - if (auto mb_si = lhs->singleton_int()) { |
17 |
| - if (auto mb_si2 = rhs->singleton_int()) { |
| 16 | + if (auto mb_si = lhs->nested_int()) { |
| 17 | + if (auto mb_si2 = rhs->nested_int()) { |
18 | 18 | if (*mb_si == *mb_si2) {
|
19 |
| - return lhs->singleton_coeff() >= rhs->singleton_coeff(); |
| 19 | + return lhs->nested_int_coeff() >= rhs->nested_int_coeff(); |
20 | 20 | }
|
21 |
| - TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate"); |
| 21 | + TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate"); |
22 | 22 | }
|
23 | 23 | // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
24 | 24 | if (rhs->constant_int() && *rhs->constant_int() <= 2) {
|
25 | 25 | return true;
|
26 | 26 | }
|
27 |
| - TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate"); |
28 |
| - } else if (rhs->singleton_int()) { |
| 27 | + TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate"); |
| 28 | + } else if (rhs->nested_int()) { |
29 | 29 | // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
30 | 30 | if (lhs->constant_int() && *lhs->constant_int() < 2) {
|
31 | 31 | return false;
|
32 | 32 | }
|
33 |
| - TORCH_CHECK(false, "Singleton int ", op, ": Relation is indeterminate"); |
| 33 | + TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate"); |
34 | 34 | }
|
35 |
| - TORCH_INTERNAL_ASSERT(false, "expect at least one singleton"); |
| 35 | + TORCH_INTERNAL_ASSERT(false, "expect at least one nested int"); |
36 | 36 | }
|
37 | 37 | } // namespace
|
38 | 38 |
|
39 |
| -c10::SymNode SingletonSymNodeImpl::eq(const c10::SymNode& other) { |
| 39 | +c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) { |
40 | 40 | return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
41 | 41 | _eq("eq", this, other.get())));
|
42 | 42 | }
|
43 | 43 |
|
44 |
| -c10::SymNode SingletonSymNodeImpl::ne(const c10::SymNode& other) { |
| 44 | +c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) { |
45 | 45 | return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
46 | 46 | !_eq("ne", this, other.get())));
|
47 | 47 | }
|
48 | 48 |
|
49 |
| -c10::SymNode SingletonSymNodeImpl::ge(const c10::SymNode& other) { |
| 49 | +c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) { |
50 | 50 | return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
51 | 51 | _ge("ge", this, other.get())));
|
52 | 52 | }
|
53 | 53 |
|
54 |
| -c10::SymNode SingletonSymNodeImpl::gt(const c10::SymNode& other) { |
| 54 | +c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) { |
55 | 55 | return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
56 | 56 | !_ge("gt", other.get(), this)));
|
57 | 57 | }
|
58 | 58 |
|
59 |
| -c10::SymNode SingletonSymNodeImpl::lt(const c10::SymNode& other) { |
| 59 | +c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) { |
60 | 60 | return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
61 | 61 | !_ge("lt", this, other.get())));
|
62 | 62 | }
|
63 | 63 |
|
64 |
| -c10::SymNode SingletonSymNodeImpl::le(const c10::SymNode& other) { |
| 64 | +c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) { |
65 | 65 | return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(
|
66 | 66 | _ge("le", other.get(), this)));
|
67 | 67 | }
|
68 | 68 |
|
69 |
| -c10::SymNode SingletonSymNodeImpl::mul(const c10::SymNode& other) { |
70 |
| - if (auto mb_si = other->singleton_int()) { |
71 |
| - TORCH_CHECK(false, "Singleton int cannot be multiplied by singleton int"); |
| 69 | +c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) { |
| 70 | + if (auto mb_si = other->nested_int()) { |
| 71 | + TORCH_CHECK(false, "nested int cannot be multiplied by nested int"); |
72 | 72 | }
|
73 | 73 | c10::optional<int64_t> c = other->constant_int();
|
74 | 74 | TORCH_CHECK(c.has_value());
|
75 |
| - return SymNode(c10::make_intrusive<SingletonSymNodeImpl>(val_, coeff_ * *c)); |
| 75 | + return SymNode(c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_ * *c)); |
76 | 76 | }
|
77 | 77 |
|
78 | 78 | } // namespace c10
|
0 commit comments