@@ -249,29 +249,68 @@ namespace impl {
249
249
250
250
// ivalue_to_arg
251
251
252
+ template <class T >
253
+ struct decay_if_not_tensor final {
254
+ using type = std::decay_t <T>;
255
+ };
256
+
257
+ template <>
258
+ struct decay_if_not_tensor <at::Tensor&> final {
259
+ using type = at::Tensor&;
260
+ };
261
+
262
+ template <>
263
+ struct decay_if_not_tensor <const at::Tensor&> final {
264
+ using type = const at::Tensor&;
265
+ };
266
+
252
267
template <class T , bool AllowDeprecatedTypes>
253
268
struct ivalue_to_arg final {
254
- static T call (IValue& & v) {
269
+ static decltype ( auto ) call(IValue& v) {
255
270
assert_is_valid_input_type<T, AllowDeprecatedTypes>();
256
271
return std::move (v).to <T>();
257
272
}
258
273
};
259
274
275
+ // The following two specializations take advantage of specialized
276
+ // `toTensor()` overloads on IValue to avoid copying.
277
+ template <bool AllowDeprecatedTypes>
278
+ struct ivalue_to_arg <at::Tensor&, AllowDeprecatedTypes> final {
279
+ // We cannot use the default implementation if they asked for a
280
+ // `at::Tensor&` because it moves from the IValue, so it can't get
281
+ // an lvalue reference.
282
+ static at::Tensor& call (IValue& v) {
283
+ // Tensor& is valid, don't bother asserting
284
+ return v.toTensor ();
285
+ }
286
+ };
287
+
288
+ template <bool AllowDeprecatedTypes>
289
+ struct ivalue_to_arg <const at::Tensor&, AllowDeprecatedTypes> final {
290
+ // We should not use the default implementation if they asked for
291
+ // a `const at::Tensor&` because it moves from the IValue and they
292
+ // didn't ask for that.
293
+ static const at::Tensor& call (IValue& v) {
294
+ // const Tensor& is valid, don't bother asserting
295
+ return v.toTensor ();
296
+ }
297
+ };
298
+
260
299
template <class T , bool AllowDeprecatedTypes>
261
300
struct ivalue_to_arg <ArrayRef<T>, AllowDeprecatedTypes> final {
262
301
// If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that
263
302
// to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>.
264
- static std::vector<T> call (IValue&& v) {
265
- return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call (std::move (v) );
303
+ static std::vector<T> call (IValue& v) {
304
+ return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call (v );
266
305
}
267
306
};
268
307
template <class T , bool AllowDeprecatedTypes>
269
308
struct ivalue_to_arg <optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
270
309
// If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
271
310
// to the operator. OptionalArray<T> is basically an optional<std::vector<T>> but implicitly convertible
272
311
// to optional<ArrayRef<T>>.
273
- static OptionalArray<T> call (IValue&& v) {
274
- return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call (std::move (v) );
312
+ static OptionalArray<T> call (IValue& v) {
313
+ return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call (v );
275
314
}
276
315
};
277
316
@@ -296,19 +335,6 @@ namespace impl {
296
335
}
297
336
};
298
337
299
- // reference_cast allows casting references, e.g. T&& to T&:
300
- // T make_t() {}
301
- // T& v = reference_cast<T&>(make_t()); // make_t() returns a T&& which is cast to T&.
302
- // If the target is a non-reference value, then it gets moved:
303
- // T make_t() {}
304
- // T v = reference_cast<T>(make_t()); // no copies involved
305
- // The first example actually also shows why reference_cast is usually a very bad idea. v now is a lvalue
306
- // reference to a dead temporary. Use with caution!
307
- template <class T , class U >
308
- T reference_cast (U&& t) {
309
- return std::forward<T>(t);
310
- }
311
-
312
338
// wrap_kernel_functor_unboxed_
313
339
314
340
template <class KernelFunctor , class OpSignature >
@@ -363,21 +389,14 @@ namespace impl {
363
389
call_functor_with_args_from_stack_ (OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) {
364
390
(void )(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
365
391
366
- /*
367
- * For ops that take "Tensor&" as an argument, ivalue_to_arg would still return a "Tensor" by value
368
- * and C++ doesn't allow us to call (*functor) with a temporary "Tensor" when it expects "Tensor&".
369
- * We use reference_cast to explicitly cast our temporary to a "Tensor&" and make it pass the compiler.
370
- * Even though usually dangerous, this is ok here because temporaries live until the end of the statement.
371
- * TODO We should remove reference_cast once kernels don't take "Tensor&" arguments anymore
372
- */
373
392
// We're explicitly filtering out DispatchKeySet from the argument list.
374
393
// Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
375
394
// We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
376
395
// See Note [Plumbing Keys Through The Dispatcher] for the background.
377
- return wrap_kernel_functor_unboxed<Functor>::call (functor, dispatchKeySet, reference_cast<ArgTypes>(
378
- ivalue_to_arg<std:: decay_t <ArgTypes>, AllowDeprecatedTypes>::call (
379
- std::move ( torch::jit::peek (*stack, ivalue_arg_indices, sizeof ...(ivalue_arg_indices) ))
380
- )) ...);
396
+ return wrap_kernel_functor_unboxed<Functor>::call (functor, dispatchKeySet,
397
+ ivalue_to_arg<typename decay_if_not_tensor <ArgTypes>::type , AllowDeprecatedTypes>::call (
398
+ torch::jit::peek (*stack, ivalue_arg_indices, sizeof ...(ivalue_arg_indices))
399
+ )...);
381
400
}
382
401
383
402
template <class Functor , bool AllowDeprecatedTypes>
0 commit comments