@@ -30,6 +30,9 @@ struct FunctionSchema;
30
30
struct NamedType ;
31
31
using OptNameList = c10::optional<std::vector<std::string>>;
32
32
33
+ void standardizeVectorForUnion (std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
34
+ void standardizeVectorForUnion (std::vector<TypePtr>* to_flatten);
35
+
33
36
struct AnyType ;
34
37
using AnyTypePtr = std::shared_ptr<AnyType>;
35
38
// Any is the top of the type hierarchy, all other types are subtypes
@@ -94,25 +97,84 @@ struct SingleElementType : public Type {
94
97
TypePtr elem;
95
98
};
96
99
100
+ struct UnionType ;
101
+ using UnionTypePtr = std::shared_ptr<UnionType>;
102
+ struct TORCH_API UnionType : public Type {
103
+ friend struct Type ;
104
+
105
+ static const TypeKind Kind = TypeKind::UnionType;
106
+
107
+ bool isSubtypeOfExt (const TypePtr& rhs_, std::ostream* why_not) const override ;
108
+
109
+ std::string str () const override ;
110
+
111
+ static UnionTypePtr create (std::vector<TypePtr> reference);
112
+
113
+ bool operator ==(const Type& rhs) const override ;
114
+
115
+ at::ArrayRef<TypePtr> containedTypes () const override {
116
+ return types_;
117
+ }
118
+
119
+ // For testing purposes only
120
+ at::ArrayRef<TypePtr> getTypes () const {
121
+ return types_;
122
+ }
123
+
124
+ TypePtr createWithContained (std::vector<TypePtr> contained_types) const override {
125
+ return create (contained_types);
126
+ }
127
+
128
+ bool canHoldType (TypePtr type) const ;
129
+
130
+ bool hasFreeVariables () const override {
131
+ return has_free_variables_;
132
+ }
133
+
134
+ c10::optional<TypePtr> toOptional () const ;
135
+
136
+ c10::optional<TypePtr> subtractTypeSet (std::vector<TypePtr>& to_subtract) const ;
137
+
138
+ protected:
139
+ explicit UnionType (std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
140
+ std::string annotation_str_impl (TypePrinter printer = nullptr ) const override ;
141
+ std::string unionStr (TypePrinter printer = nullptr , bool is_annotation_str = false ) const ;
142
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
143
+ bool has_free_variables_;
144
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
145
+ std::vector<TypePtr> types_;
146
+ // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
147
+ bool can_hold_none_;
148
+
149
+ };
150
+
97
151
struct OptionalType ;
98
152
using OptionalTypePtr = std::shared_ptr<OptionalType>;
99
- // This type represents an optional type, for each element type.
100
- // Optional[T] can accept both T and None(nullopt in C++)
153
+ // This type represents an optional type. There is one `Optional` for
154
+ // each element type. `Optional[T]` can accept both `T` and
155
+ // `None`(`c10::nullopt` in C++)
101
156
// Subtype hierarchy for Optional:
102
- // 1. Optional[T] <: Optional[R] iff T <: R
103
- // 2. T <: Optional[R] if T <: R
104
- // 3. None <: Optional[T] for all T
105
- struct TORCH_API OptionalType
106
- : public SingleElementType<TypeKind::OptionalType, OptionalType> {
107
- static OptionalTypePtr create (TypePtr element) {
108
- TORCH_INTERNAL_ASSERT (element, " OptionalType requires valid TypePtr" );
109
- // Optional is a union of [None, T], so Optional[[Optional[T]]] ->
110
- // Optional[T]
111
- if (auto opt_ptr = element->cast <OptionalType>()) {
112
- return opt_ptr;
113
- }
114
- return OptionalTypePtr (
115
- new OptionalType (std::move (element))); // NOLINT(modernize-make-shared)
157
+ // - Optional[T] <: Optional[R] iff T <: R
158
+ // - T <: Optional[R] if T <: R
159
+ // - None <: Optional[T] for all T
160
+ // - Optional[T] == Union[T, None] for all T
161
+ struct TORCH_API OptionalType : public UnionType {
162
+ static OptionalTypePtr create (TypePtr contained) {
163
+ return OptionalTypePtr (new OptionalType (std::move (contained)));
164
+ }
165
+
166
+ static const TypeKind Kind = TypeKind::OptionalType;
167
+
168
+ friend struct Type ;
169
+
170
+ bool operator ==(const Type& rhs) const override ;
171
+
172
+ TypePtr getElementType () const {
173
+ return contained_;
174
+ }
175
+
176
+ at::ArrayRef<TypePtr> containedTypes () const override {
177
+ return contained_;
116
178
}
117
179
118
180
std::string str () const override {
@@ -127,20 +189,15 @@ struct TORCH_API OptionalType
127
189
return create (contained_types[0 ]);
128
190
}
129
191
130
- bool isSubtypeOfExt (const TypePtr& rhs, std::ostream* why_not) const override {
131
- if (Type::isSubtypeOfExt (rhs, why_not)) {
132
- return true ;
133
- }
134
- if (auto rhs_ = rhs->cast <OptionalType>()) {
135
- return getElementType ()->isSubtypeOfExt (rhs_->getElementType (), why_not);
136
- }
137
- return false ;
138
- }
192
+ bool isSubtypeOfExt (const TypePtr& rhs, std::ostream* why_not) const override ;
193
+
139
194
// common cast Optional[Tensor] for undefined tensor type
140
195
static OptionalTypePtr ofTensor ();
141
196
142
197
private:
143
- OptionalType (TypePtr elem) : SingleElementType(elem) {}
198
+ explicit OptionalType (TypePtr contained);
199
+
200
+ TypePtr contained_;
144
201
145
202
std::string annotation_str_impl (TypePrinter printer = nullptr ) const override {
146
203
std::stringstream ss;
@@ -908,7 +965,6 @@ struct TORCH_API RRefType
908
965
}
909
966
};
910
967
911
-
912
968
struct NamedType ;
913
969
using NamedTypePtr = std::shared_ptr<NamedType>;
914
970
using ConstNamedTypePtr = std::shared_ptr<const NamedType>;
@@ -1112,7 +1168,6 @@ struct TORCH_API EnumType : public NamedType {
1112
1168
std::weak_ptr<::torch::jit::CompilationUnit> cu_;
1113
1169
};
1114
1170
1115
-
1116
1171
// the common supertype of all Enums, only used in operator registraion.
1117
1172
// EnumType <: AnyEnumType for all Enums
1118
1173
struct AnyEnumType ;
@@ -1132,7 +1187,6 @@ struct TORCH_API AnyEnumType : public Type {
1132
1187
: Type(TypeKind::AnyEnumType) {}
1133
1188
};
1134
1189
1135
-
1136
1190
struct NumberType ;
1137
1191
using NumberTypePtr = std::shared_ptr<NumberType>;
1138
1192
// This type represents a Python number
@@ -1141,9 +1195,10 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
1141
1195
// FloatType <: NumberType
1142
1196
// ComplexType <:NumberType
1143
1197
struct TORCH_API NumberType : public Type {
1144
- bool operator ==(const Type& rhs) const override {
1145
- return rhs.kind () == kind ();
1146
- }
1198
+ bool operator ==(const Type& rhs) const override ;
1199
+
1200
+ bool isSubtypeOfExt (const TypePtr& rhs, std::ostream* why_not) const override ;
1201
+
1147
1202
std::string str () const override {
1148
1203
return " Scalar" ; // match what PythonArgParser says for clarity
1149
1204
}
@@ -1172,7 +1227,8 @@ struct TORCH_API FloatType : public NumberType {
1172
1227
return " float" ;
1173
1228
}
1174
1229
bool isSubtypeOfExt (const TypePtr& rhs, std::ostream* why_not) const override {
1175
- return rhs->kind () == TypeKind::NumberType || NumberType::isSubtypeOfExt (rhs, why_not);
1230
+ // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1231
+ return rhs->kind () == TypeKind::NumberType || Type::isSubtypeOfExt (rhs, why_not);
1176
1232
}
1177
1233
static const TypeKind Kind = TypeKind::FloatType;
1178
1234
// global singleton
@@ -1196,7 +1252,8 @@ struct TORCH_API ComplexType : public NumberType {
1196
1252
return " complex" ;
1197
1253
}
1198
1254
bool isSubtypeOfExt (const TypePtr& rhs, std::ostream* why_not) const override {
1199
- return rhs->kind () == TypeKind::NumberType || NumberType::isSubtypeOfExt (rhs, why_not);
1255
+ // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1256
+ return rhs->kind () == TypeKind::NumberType || Type::isSubtypeOfExt (rhs, why_not);
1200
1257
}
1201
1258
static const TypeKind Kind = TypeKind::ComplexType;
1202
1259
// global singleton
@@ -1220,7 +1277,8 @@ struct TORCH_API IntType : public NumberType {
1220
1277
return " int" ;
1221
1278
}
1222
1279
bool isSubtypeOfExt (const TypePtr& rhs, std::ostream* why_not) const override {
1223
- return rhs->kind () == TypeKind::NumberType || NumberType::isSubtypeOfExt (rhs, why_not);
1280
+ // NOLINTNEXTLINE(bugprone-parent-virtual-call)
1281
+ return rhs->kind () == TypeKind::NumberType || Type::isSubtypeOfExt (rhs, why_not);
1224
1282
}
1225
1283
static const TypeKind Kind = TypeKind::IntType;
1226
1284
// global singleton
@@ -1334,12 +1392,8 @@ struct TORCH_API NoneType : public Type {
1334
1392
std::string str () const override {
1335
1393
return " NoneType" ;
1336
1394
}
1337
- bool isSubtypeOfExt (const TypePtr& rhs, std::ostream *why_not) const override {
1338
- if (rhs->kind () == OptionalType::Kind) {
1339
- return true ;
1340
- }
1341
- return Type::isSubtypeOfExt (rhs, why_not);
1342
- }
1395
+ bool isSubtypeOfExt (const TypePtr& rhs, std::ostream *why_not) const override ;
1396
+
1343
1397
static const TypeKind Kind = TypeKind::NoneType;
1344
1398
// global singleton
1345
1399
static NoneTypePtr get ();
@@ -1524,8 +1578,15 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
1524
1578
// what is the type, ignoring extra size/shape information?
1525
1579
// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
1526
1580
1527
- // xxx: be careful with calls because this can be very slow. If calling this on a graph
1528
- // use `EraseShapeInformation` in shape_analysis.h
1581
+ // `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
1582
+ // subtypes as simply "Tensor"; we also create a new version of any
1583
+ // container types in which internal Tensors have undergone the same
1584
+ // operation. This is used for type comparisons between two Tensor types
1585
+ // (`unshapedType` means that we don't falsely return `false` for e.g.
1586
+ // Tensors of different dimensions). It's also used in the alias
1587
+ // analysis pass.
1588
+ // Be careful with calls because this can be very slow. If calling this
1589
+ // on a graph, use `EraseShapeInformation` in shape_analysis.h
1529
1590
inline TypePtr unshapedType (const TypePtr& type) {
1530
1591
if (type->isSubtypeOf (TensorType::get ())) {
1531
1592
return TensorType::get ();
@@ -1569,27 +1630,32 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
1569
1630
return *result;
1570
1631
}
1571
1632
1572
- // Attempt to find the correct supertype of t1 and t2. If none is found then
1573
- // nullopt will be returned if default_to_any is false, and Any will be returned
1574
- // if it is true. If t1 == t2, or t1 is a type refinement of t2,
1575
- // then t2 will be returned (and vice versa).
1633
+ // Attempt to find the correct supertype of the two types `t1` and `t2`.
1634
+ // If no supertype is found, then nullopt will be returned if
1635
+ // `default_to_union` is false, and `Union[t1, t2]` will be returned
1636
+ // if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
1637
+ // then `t2` will be returned (and vice versa).
1638
+ //
1576
1639
// Two different tensortypes will return dynamic.
1577
- // Currently we chose not to support returning a NumberType for a float & int
1578
- // input because of a lack of operator support for NumberType.
1640
+ //
1641
+ // Currently we chose not to support returning a NumberType for
1642
+ // two types from the set of {FloatType, IntType, ComplexType}, because
1643
+ // there is a lack of operator support for NumberType.
1644
+ //
1579
1645
// If `type_hint` is an `InterfaceType`, then we can use that as a
1580
1646
// potential supertype for `ClassType`s in the list. Otherwise, we have
1581
1647
// no way to find and use some common interface type
1582
1648
TORCH_API c10::optional<TypePtr> unifyTypes (
1583
1649
const TypePtr& t1,
1584
1650
const TypePtr& t2,
1585
- bool default_to_any = false ,
1586
- TypePtr type_hint= nullptr );
1651
+ bool default_to_union = false ,
1652
+ TypePtr type_hint = nullptr );
1587
1653
1588
1654
TORCH_API c10::optional<TypePtr> unifyTypeList (
1589
1655
at::ArrayRef<TypePtr> elements,
1590
1656
std::ostream& why_not,
1591
- bool default_to_any= false ,
1592
- TypePtr type_hint= nullptr );
1657
+ bool default_to_union = false ,
1658
+ TypePtr type_hint = nullptr );
1593
1659
1594
1660
namespace detail {
1595
1661
template <typename T>
0 commit comments