Skip to content

Commit 6831d8e

Browse files
Ansley Usseryfacebook-github-bot
Ansley Ussery
authored andcommitted
Support Union in TorchScript (pytorch#64234)
Summary: This PR is created to replace pytorch#53180 PR stack, which has all the review discussions. Reason for needing a replacement is due to a messy Sandcastle issue. Pull Request resolved: pytorch#64234 Reviewed By: gmagogsfm Differential Revision: D30656444 Pulled By: ansley fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
1 parent 91b926f commit 6831d8e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2132
-462
lines changed

CONTRIBUTING.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -435,12 +435,12 @@ is `./build/bin/FILENAME --gtest_filter=TESTSUITE.TESTNAME`, where
435435
`TESTNAME` is the name of the test you'd like to run and `TESTSUITE` is
436436
the suite that test is defined in.
437437

438-
For example, if you wanted to run the test ` MayContainAlias`, which
438+
For example, if you wanted to run the test `MayContainAlias`, which
439439
is part of the test suite `ContainerAliasingTest` in the file
440440
`test/cpp/jit/test_alias_analysis.cpp`, the command would be:
441441

442442
```bash
443-
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.UnionAliasing
443+
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.MayContainAlias
444444
```
445445

446446

aten/src/ATen/core/jit_type.h

+119-53
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ struct FunctionSchema;
3030
struct NamedType;
3131
using OptNameList = c10::optional<std::vector<std::string>>;
3232

33+
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
34+
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
35+
3336
struct AnyType;
3437
using AnyTypePtr = std::shared_ptr<AnyType>;
3538
// Any is the top of the type hierarchy, all other types are subtypes
@@ -94,25 +97,84 @@ struct SingleElementType : public Type {
9497
TypePtr elem;
9598
};
9699

100+
struct UnionType;
101+
using UnionTypePtr = std::shared_ptr<UnionType>;
102+
struct TORCH_API UnionType : public Type {
103+
friend struct Type;
104+
105+
static const TypeKind Kind = TypeKind::UnionType;
106+
107+
bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override;
108+
109+
std::string str() const override;
110+
111+
static UnionTypePtr create(std::vector<TypePtr> reference);
112+
113+
bool operator==(const Type& rhs) const override;
114+
115+
at::ArrayRef<TypePtr> containedTypes() const override {
116+
return types_;
117+
}
118+
119+
// For testing purposes only
120+
at::ArrayRef<TypePtr> getTypes() const {
121+
return types_;
122+
}
123+
124+
TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
125+
return create(contained_types);
126+
}
127+
128+
bool canHoldType(TypePtr type) const;
129+
130+
bool hasFreeVariables() const override {
131+
return has_free_variables_;
132+
}
133+
134+
c10::optional<TypePtr> toOptional() const;
135+
136+
c10::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;
137+
138+
protected:
139+
explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
140+
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
141+
std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const;
142+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
143+
bool has_free_variables_;
144+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
145+
std::vector<TypePtr> types_;
146+
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
147+
bool can_hold_none_;
148+
149+
};
150+
97151
struct OptionalType;
98152
using OptionalTypePtr = std::shared_ptr<OptionalType>;
99-
// This type represents an optional type, for each element type.
100-
// Optional[T] can accept both T and None(nullopt in C++)
153+
// This type represents an optional type. There is one `Optional` for
154+
// each element type. `Optional[T]` can accept both `T` and
155+
// `None`(`c10::nullopt` in C++)
101156
// Subtype hierarchy for Optional:
102-
// 1. Optional[T] <: Optional[R] iff T <: R
103-
// 2. T <: Optional[R] if T <: R
104-
// 3. None <: Optional[T] for all T
105-
struct TORCH_API OptionalType
106-
: public SingleElementType<TypeKind::OptionalType, OptionalType> {
107-
static OptionalTypePtr create(TypePtr element) {
108-
TORCH_INTERNAL_ASSERT(element, "OptionalType requires valid TypePtr");
109-
// Optional is a union of [None, T], so Optional[[Optional[T]]] ->
110-
// Optional[T]
111-
if (auto opt_ptr = element->cast<OptionalType>()) {
112-
return opt_ptr;
113-
}
114-
return OptionalTypePtr(
115-
new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
157+
// - Optional[T] <: Optional[R] iff T <: R
158+
// - T <: Optional[R] if T <: R
159+
// - None <: Optional[T] for all T
160+
// - Optional[T] == Union[T, None] for all T
161+
struct TORCH_API OptionalType : public UnionType {
162+
static OptionalTypePtr create(TypePtr contained) {
163+
return OptionalTypePtr(new OptionalType(std::move(contained)));
164+
}
165+
166+
static const TypeKind Kind = TypeKind::OptionalType;
167+
168+
friend struct Type;
169+
170+
bool operator==(const Type& rhs) const override;
171+
172+
TypePtr getElementType() const {
173+
return contained_;
174+
}
175+
176+
at::ArrayRef<TypePtr> containedTypes() const override {
177+
return contained_;
116178
}
117179

118180
std::string str() const override {
@@ -127,20 +189,15 @@ struct TORCH_API OptionalType
127189
return create(contained_types[0]);
128190
}
129191

130-
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
131-
if (Type::isSubtypeOfExt(rhs, why_not)) {
132-
return true;
133-
}
134-
if (auto rhs_ = rhs->cast<OptionalType>()) {
135-
return getElementType()->isSubtypeOfExt(rhs_->getElementType(), why_not);
136-
}
137-
return false;
138-
}
192+
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
193+
139194
// common cast Optional[Tensor] for undefined tensor type
140195
static OptionalTypePtr ofTensor();
141196

142197
private:
143-
OptionalType(TypePtr elem) : SingleElementType(elem) {}
198+
explicit OptionalType(TypePtr contained);
199+
200+
TypePtr contained_;
144201

145202
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
146203
std::stringstream ss;
@@ -908,7 +965,6 @@ struct TORCH_API RRefType
908965
}
909966
};
910967

911-
912968
struct NamedType;
913969
using NamedTypePtr = std::shared_ptr<NamedType>;
914970
using ConstNamedTypePtr = std::shared_ptr<const NamedType>;
@@ -1112,7 +1168,6 @@ struct TORCH_API EnumType : public NamedType {
11121168
std::weak_ptr<::torch::jit::CompilationUnit> cu_;
11131169
};
11141170

1115-
11161171
// the common supertype of all Enums, only used in operator registraion.
11171172
// EnumType <: AnyEnumType for all Enums
11181173
struct AnyEnumType;
@@ -1132,7 +1187,6 @@ struct TORCH_API AnyEnumType : public Type {
11321187
: Type(TypeKind::AnyEnumType) {}
11331188
};
11341189

1135-
11361190
struct NumberType;
11371191
using NumberTypePtr = std::shared_ptr<NumberType>;
11381192
// This type represents a Python number
@@ -1141,9 +1195,10 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
11411195
// FloatType <: NumberType
11421196
// ComplexType <:NumberType
11431197
struct TORCH_API NumberType : public Type {
1144-
bool operator==(const Type& rhs) const override {
1145-
return rhs.kind() == kind();
1146-
}
1198+
bool operator==(const Type& rhs) const override;
1199+
1200+
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
1201+
11471202
std::string str() const override {
11481203
return "Scalar"; // match what PythonArgParser says for clarity
11491204
}
@@ -1172,7 +1227,8 @@ struct TORCH_API FloatType : public NumberType {
11721227
return "float";
11731228
}
11741229
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
1175-
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
1230+
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
1231+
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
11761232
}
11771233
static const TypeKind Kind = TypeKind::FloatType;
11781234
// global singleton
@@ -1196,7 +1252,8 @@ struct TORCH_API ComplexType : public NumberType {
11961252
return "complex";
11971253
}
11981254
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
1199-
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
1255+
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
1256+
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
12001257
}
12011258
static const TypeKind Kind = TypeKind::ComplexType;
12021259
// global singleton
@@ -1220,7 +1277,8 @@ struct TORCH_API IntType : public NumberType {
12201277
return "int";
12211278
}
12221279
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
1223-
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
1280+
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
1281+
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
12241282
}
12251283
static const TypeKind Kind = TypeKind::IntType;
12261284
// global singleton
@@ -1334,12 +1392,8 @@ struct TORCH_API NoneType : public Type {
13341392
std::string str() const override {
13351393
return "NoneType";
13361394
}
1337-
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override {
1338-
if (rhs->kind() == OptionalType::Kind) {
1339-
return true;
1340-
}
1341-
return Type::isSubtypeOfExt(rhs, why_not);
1342-
}
1395+
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override;
1396+
13431397
static const TypeKind Kind = TypeKind::NoneType;
13441398
// global singleton
13451399
static NoneTypePtr get();
@@ -1524,8 +1578,15 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
15241578
// what is the type, ignoring extra size/shape information?
15251579
// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
15261580

1527-
// xxx: be careful with calls because this can be very slow. If calling this on a graph
1528-
// use `EraseShapeInformation` in shape_analysis.h
1581+
// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
1582+
// subtypes as simply "Tensor"; we also create a new version of any
1583+
// container types in which internal Tensors have undergone the same
1584+
// operation. This is used for type comparisons between two Tensor types
1585+
// (`unshapedType` means that we don't falsely return `false` for e.g.
1586+
// Tensors of different dimensions). It's also used in the alias
1587+
// analysis pass.
1588+
// Be careful with calls because this can be very slow. If calling this
1589+
// on a graph, use `EraseShapeInformation` in shape_analysis.h
15291590
inline TypePtr unshapedType(const TypePtr& type) {
15301591
if (type->isSubtypeOf(TensorType::get())) {
15311592
return TensorType::get();
@@ -1569,27 +1630,32 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
15691630
return *result;
15701631
}
15711632

1572-
// Attempt to find the correct supertype of t1 and t2. If none is found then
1573-
// nullopt will be returned if default_to_any is false, and Any will be returned
1574-
// if it is true. If t1 == t2, or t1 is a type refinement of t2,
1575-
// then t2 will be returned (and vice versa).
1633+
// Attempt to find the correct supertype of the two types `t1` and `t2`.
1634+
// If no supertype is found, then nullopt will be returned if
1635+
// `default_to_union` is false, and `Union[t1, t2]` will be returned
1636+
// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
1637+
// then `t2` will be returned (and vice versa).
1638+
//
15761639
// Two different tensortypes will return dynamic.
1577-
// Currently we chose not to support returning a NumberType for a float & int
1578-
// input because of a lack of operator support for NumberType.
1640+
//
1641+
// Currently we chose not to support returning a NumberType for
1642+
// two types from the set of {FloatType, IntType, ComplexType}, because
1643+
// there is a lack of operator support for NumberType.
1644+
//
15791645
// If `type_hint` is an `InterfaceType`, then we can use that as a
15801646
// potential supertype for `ClassType`s in the list. Otherwise, we have
15811647
// no way to find and use some common interface type
15821648
TORCH_API c10::optional<TypePtr> unifyTypes(
15831649
const TypePtr& t1,
15841650
const TypePtr& t2,
1585-
bool default_to_any = false,
1586-
TypePtr type_hint=nullptr);
1651+
bool default_to_union = false,
1652+
TypePtr type_hint = nullptr);
15871653

15881654
TORCH_API c10::optional<TypePtr> unifyTypeList(
15891655
at::ArrayRef<TypePtr> elements,
15901656
std::ostream& why_not,
1591-
bool default_to_any=false,
1592-
TypePtr type_hint=nullptr);
1657+
bool default_to_union = false,
1658+
TypePtr type_hint = nullptr);
15931659

15941660
namespace detail {
15951661
template <typename T>

aten/src/ATen/core/jit_type_base.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace c10 {
2121
_(DictType) \
2222
_(NumberType) \
2323
_(FloatType) \
24-
_(ComplexType) \
24+
_(ComplexType) \
2525
_(FutureType) \
2626
_(RRefType) \
2727
_(IntType) \
@@ -44,7 +44,8 @@ namespace c10 {
4444
_(ScalarTypeType) \
4545
_(AnyListType) \
4646
_(AnyTupleType) \
47-
_(AnyClassType)
47+
_(AnyClassType) \
48+
_(UnionType)
4849

4950
enum class TypeKind {
5051
#define DEFINE_TYPE(T) T,
@@ -203,7 +204,7 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
203204
// contained_types
204205
TypePtr withContained(std::vector<TypePtr> contained_types) {
205206
auto current_contained = containedTypes();
206-
AT_ASSERT(current_contained.size() == contained_types.size());
207+
TORCH_INTERNAL_ASSERT(current_contained.size() == contained_types.size());
207208
if (current_contained.equals(contained_types)) {
208209
return shared_from_this();
209210
}

0 commit comments

Comments
 (0)