Skip to content

Commit b2aa63f

Browse files
swolchokfacebook-github-bot
authored andcommitted
[PyTorch] Fix return value of IValue::to for Tensor/String (pytorch#51463)
Summary: Pull Request resolved: pytorch#51463 We can make the return type of the `to()` template match the return type of toFoo() by using the same technique we use for `list_element_to_const_ref`. Also simplifies `list_element_to_const_ref`. ghstack-source-id: 121363468 Test Plan: CI built and ran AdIndexer benchmark w/ batch size 1 under perf stat --repeat 5 to make sure it didn't regress Reviewed By: bhosmer Differential Revision: D26163848 fbshipit-source-id: b8563263b9f9fa5311c7d7cedc89e28bc5badda0
1 parent a9f5e72 commit b2aa63f

File tree

5 files changed

+46
-30
lines changed

5 files changed

+46
-30
lines changed

aten/src/ATen/core/List.h

+4-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <ATen/core/ivalue_to.h>
34
#include <c10/macros/Macros.h>
45
#include <c10/util/TypeTraits.h>
56
#include <c10/util/TypeList.h>
@@ -55,26 +56,16 @@ bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
5556

5657
template<class T>
5758
struct ListElementConstReferenceTraits {
58-
// In the general case, we cannot expose a true const reference to
59-
// the contents of an IValue, so we copy.
60-
using const_reference = T;
61-
};
62-
63-
template<>
64-
struct ListElementConstReferenceTraits<std::string> {
65-
using const_reference = const std::string&;
59+
// In the general case, we use IValue::to().
60+
using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
6661
};
6762

63+
// There is no to() overload for c10::optional<std::string>.
6864
template<>
6965
struct ListElementConstReferenceTraits<c10::optional<std::string>> {
7066
using const_reference = c10::optional<std::reference_wrapper<const std::string>>;
7167
};
7268

73-
template<>
74-
struct ListElementConstReferenceTraits<at::Tensor> {
75-
using const_reference = const at::Tensor&;
76-
};
77-
7869
template<class T, class Iterator>
7970
class ListElementReference final {
8071
public:

aten/src/ATen/core/List_inl.h

-12
Original file line numberDiff line numberDiff line change
@@ -167,24 +167,12 @@ list_element_to_const_ref(const IValue& element) {
167167
return element.template to<T>();
168168
}
169169

170-
template<>
171-
inline typename ListElementConstReferenceTraits<std::string>::const_reference
172-
list_element_to_const_ref<std::string>(const IValue& element) {
173-
return element.toStringRef();
174-
}
175-
176170
template<>
177171
inline typename ListElementConstReferenceTraits<c10::optional<std::string>>::const_reference
178172
list_element_to_const_ref<c10::optional<std::string>>(const IValue& element) {
179173
return element.toOptionalStringRef();
180174
}
181175

182-
template<>
183-
inline typename ListElementConstReferenceTraits<at::Tensor>::const_reference
184-
list_element_to_const_ref<at::Tensor>(const IValue& element) {
185-
return element.toTensor();
186-
}
187-
188176
} // namespace impl
189177

190178
template<class T>

aten/src/ATen/core/ivalue.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <ATen/core/TensorBody.h>
44
#include <ATen/core/blob.h>
5+
#include <ATen/core/ivalue_to.h>
56
#include <c10/util/C++17.h>
67
#include <c10/util/intrusive_ptr.h>
78
#include <torch/csrc/WindowsTorchApiMacro.h>
@@ -781,7 +782,7 @@ struct TORCH_API IValue final {
781782
template <typename T>
782783
T to() &&;
783784
template <typename T>
784-
T to() const&;
785+
typename c10::detail::ivalue_to_const_ref_overload_return<T>::type to() const&;
785786

786787
// ToOptional: convert a IValue to the Optional obj that accepts both T and
787788
// None

aten/src/ATen/core/ivalue_inl.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -765,13 +765,13 @@ inline const ivalue::Object& IValue::toObjectRef() const {
765765
// toX method to IValue. These named methods are much more discoverable
766766
// than the to templated function.
767767

768-
#define DEFINE_TO(type, method_name) \
768+
#define DEFINE_TO(T, method_name) \
769769
template <> \
770-
inline type IValue::to<type>()&& { \
770+
inline T IValue::to<T>()&& { \
771771
return std::move(*this).method_name(); \
772772
} \
773773
template <> \
774-
inline type IValue::to<type>() const& { \
774+
inline c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to<T>() const& { \
775775
return this->method_name(); \
776776
}
777777

@@ -1014,7 +1014,7 @@ inline T IValue::to() && {
10141014
}
10151015

10161016
template <typename T>
1017-
inline T IValue::to() const& {
1017+
inline typename c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to() const& {
10181018
return generic_to(*this, _fake_type<T>{});
10191019
}
10201020

aten/src/ATen/core/ivalue_to.h

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
3+
#include <string>
4+
5+
namespace at {
6+
class Tensor;
7+
} // namespace at
8+
9+
namespace c10 {
10+
struct IValue;
11+
namespace detail {
12+
// Determine the return type of `IValue::to() const &`. It's a const
13+
// reference when possible and a copy otherwise. It is in this
14+
// separate header so that List can use it as well.
15+
template<typename T>
16+
struct ivalue_to_const_ref_overload_return {
17+
using type = T;
18+
};
19+
20+
template<>
21+
struct ivalue_to_const_ref_overload_return<at::Tensor> {
22+
using type = const at::Tensor&;
23+
};
24+
25+
template<>
26+
struct ivalue_to_const_ref_overload_return<std::string> {
27+
using type = const std::string&;
28+
};
29+
30+
template<>
31+
struct ivalue_to_const_ref_overload_return<IValue> {
32+
using type = const IValue&;
33+
};
34+
35+
} // namespace detail
36+
} // namespace c10

0 commit comments

Comments
 (0)