Skip to content

Commit a9f5e72

Browse files
swolchok authored and facebook-github-bot committed
[PyTorch] Remove reference_cast in make_boxed_from_unboxed_functor (pytorch#51319)
Summary: Pull Request resolved: pytorch#51319 We were going out of our way to accommodate `IValue::to<Tensor>` returning a copy of the inner Tensor. `IValue::toTensor` is capable of returning a reference without copying, so if we use it directly, we can allow kernels that want to take `Tensor &` to do so! As a bonus, we get reduced build times. ghstack-source-id: 121378961 Test Plan: Rely on CI for correctness. Profiled build time with -ftime-trace for RegisterCPU.cpp using an extracted build invocation. Before: P168244900 After: P168245014 Note reduced time spent compiling make_boxed_from_unboxed_functor. I also ran the AdIndexer benchmark (https://fb.quip.com/ztERAYjuzdlr) with static runtime disabled and batch size 1 to see how big the effect on boxed call performance was (any kernels that take `Tensor&` or `const Tensor&` should now actually save a refcount bump). Looks like it was roughly 1% better: Before: 124-125 usec/iter After: 122-123 usec/iter Reviewed By: bhosmer Differential Revision: D26138549 fbshipit-source-id: b0f830527da360c542c815bef2f7e1692615b32a
1 parent c442776 commit a9f5e72

File tree

3 files changed

+53
-34
lines changed

3 files changed

+53
-34
lines changed

aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h

+48-29
Original file line numberDiff line numberDiff line change
@@ -249,29 +249,68 @@ namespace impl {
249249

250250
// ivalue_to_arg
251251

252+
template<class T>
253+
struct decay_if_not_tensor final {
254+
using type = std::decay_t<T>;
255+
};
256+
257+
template<>
258+
struct decay_if_not_tensor<at::Tensor&> final {
259+
using type = at::Tensor&;
260+
};
261+
262+
template<>
263+
struct decay_if_not_tensor<const at::Tensor&> final {
264+
using type = const at::Tensor&;
265+
};
266+
252267
template<class T, bool AllowDeprecatedTypes>
253268
struct ivalue_to_arg final {
254-
static T call(IValue&& v) {
269+
static decltype(auto) call(IValue& v) {
255270
assert_is_valid_input_type<T, AllowDeprecatedTypes>();
256271
return std::move(v).to<T>();
257272
}
258273
};
259274

275+
// The following two specializations take advantage of specialized
276+
// `toTensor()` overloads on IValue to avoid copying.
277+
template<bool AllowDeprecatedTypes>
278+
struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
279+
// We cannot use the default implementation if they asked for a
280+
// `at::Tensor&` because it moves from the IValue, so it can't get
281+
// an lvalue reference.
282+
static at::Tensor& call(IValue& v) {
283+
// Tensor& is valid, don't bother asserting
284+
return v.toTensor();
285+
}
286+
};
287+
288+
template<bool AllowDeprecatedTypes>
289+
struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
290+
// We should not use the default implementation if they asked for
291+
// a `const at::Tensor&` because it moves from the IValue and they
292+
// didn't ask for that.
293+
static const at::Tensor& call(IValue& v) {
294+
// const Tensor& is valid, don't bother asserting
295+
return v.toTensor();
296+
}
297+
};
298+
260299
template<class T, bool AllowDeprecatedTypes>
261300
struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
262301
// If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that
263302
// to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>.
264-
static std::vector<T> call(IValue&& v) {
265-
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(std::move(v));
303+
static std::vector<T> call(IValue& v) {
304+
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
266305
}
267306
};
268307
template<class T, bool AllowDeprecatedTypes>
269308
struct ivalue_to_arg<optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
270309
// If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
271310
// to the operator. OptionalArray<T> is basically an optional<std::vector<T>> but implicitly convertible
272311
// to optional<ArrayRef<T>>.
273-
static OptionalArray<T> call(IValue&& v) {
274-
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(std::move(v));
312+
static OptionalArray<T> call(IValue& v) {
313+
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
275314
}
276315
};
277316

@@ -296,19 +335,6 @@ namespace impl {
296335
}
297336
};
298337

299-
// reference_cast allows casting references, e.g. T&& to T&:
300-
// T make_t() {}
301-
// T& v = reference_cast<T&>(make_t()); // make_t() returns a T&& which is cast to T&.
302-
// If the target is a non-reference value, then it gets moved:
303-
// T make_t() {}
304-
// T v = reference_cast<T>(make_t()); // no copies involved
305-
// The first example actually also shows why reference_cast is usually a very bad idea. v now is a lvalue
306-
// reference to a dead temporary. Use with caution!
307-
template<class T, class U>
308-
T reference_cast(U&& t) {
309-
return std::forward<T>(t);
310-
}
311-
312338
// wrap_kernel_functor_unboxed_
313339

314340
template<class KernelFunctor, class OpSignature>
@@ -363,21 +389,14 @@ namespace impl {
363389
call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) {
364390
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
365391

366-
/*
367-
* For ops that take "Tensor&" as an argument, ivalue_to_arg would still return a "Tensor" by value
368-
* and C++ doesn't allow us to call (*functor) with a temporary "Tensor" when it expects "Tensor&".
369-
* We use reference_cast to explicitly cast our temporary to a "Tensor&" and make it pass the compiler.
370-
* Even though usually dangerous, this is ok here because temporaries live until the end of the statement.
371-
* TODO We should remove reference_cast once kernels don't take "Tensor&" arguments anymore
372-
*/
373392
// We're explicitly filtering out DispatchKeySet from the argument list.
374393
// Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
375394
// We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
376395
// See Note [Plumbing Keys Through The Dispatcher] for the background.
377-
return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet, reference_cast<ArgTypes>(
378-
ivalue_to_arg<std::decay_t<ArgTypes>, AllowDeprecatedTypes>::call(
379-
std::move(torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))
380-
))...);
396+
return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet,
397+
ivalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type, AllowDeprecatedTypes>::call(
398+
torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices))
399+
)...);
381400
}
382401

383402
template<class Functor, bool AllowDeprecatedTypes>

aten/src/ATen/core/op_registration/op_registration_test.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingC
639639
expectThrows<c10::Error>([] {
640640
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
641641
.kernel(DispatchKey::CPU, [] (const int64_t&) {})
642-
.kernel(DispatchKey::CUDA, [] (int64_t&) {}));
642+
.kernel(DispatchKey::CUDA, [] (int64_t) {}));
643643
}, "Mismatch in kernel C++ signatures");
644644
}
645645

torch/custom_class_detail.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,11 @@ call_torchbind_method_from_stack(
7979
typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
8080
// TODO We shouldn't use c10::impl stuff directly here. We should use the KernelFunction API instead.
8181
return (functor)(c10::impl::ivalue_to_arg<
82-
std::remove_cv_t<std::remove_reference_t<
82+
typename c10::impl::decay_if_not_tensor<
8383
c10::guts::typelist::
84-
element_t<ivalue_arg_indices, IValueArgTypes>>>,
85-
AllowDeprecatedTypes>::call(std::move(
86-
torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args)))...);
84+
element_t<ivalue_arg_indices, IValueArgTypes>>::type,
85+
AllowDeprecatedTypes>::call(
86+
torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args))...);
8787
}
8888

8989
template <class Functor, bool AllowDeprecatedTypes>

0 commit comments

Comments
 (0)