Skip to content

Commit ac8d220

Browse files
chengjunlu authored and pytorchmergebot committed
Add `__torch_function__` override protocol support to some factory functions
## Motivation Add `__torch_function__` override protocol support to the factory functions defined in python_torch_functions_manual.cpp. ## Solution Move the PythonArgParser out of tensor_new.cpp and add torch-function handle dispatching for these APIs in the `torch` namespace: as_tensor, sparse_coo_tensor, _sparse_coo_tensor_unsafe, sparse_csr_tensor, _sparse_csr_tensor_unsafe. Pull Request resolved: pytorch#75639 Approved by: https://github.com/ezyang
1 parent 7c90171 commit ac8d220

File tree

4 files changed

+152
-57
lines changed

4 files changed

+152
-57
lines changed

test/test_overrides.py

+15
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,19 @@ def __torch_function__(self, *args, **kwargs):
11161116
self.assertEqual(torch.split(None, [2]), -1) # python side
11171117
self.assertEqual(bar(x), -1)
11181118

1119+
def test_factory_override(self):
1120+
class A(TorchFunctionMode):
1121+
def __torch_function__(self, *args, **kwargs):
1122+
return -1
1123+
1124+
with torch.overrides.push_torch_function_mode(A):
1125+
self.assertEqual(torch.tensor([1]), -1)
1126+
self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
1127+
self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
1128+
self.assertEqual(torch._sparse_coo_tensor_unsafe(1, 1, (1, 1)), -1)
1129+
self.assertEqual(torch._sparse_csr_tensor_unsafe(1, 1, 1, (1, 1)), -1)
1130+
self.assertEqual(torch.as_tensor([1]), -1)
1131+
11191132
def test_enable_torch_function_mode_with_tensor_subclass(self):
11201133
x = torch.randn(1)
11211134
with torch.overrides.enable_torch_function_mode(SubTensor):
@@ -1322,5 +1335,7 @@ class B(torch.Tensor):
13221335
self.assertEqual(called, 2)
13231336

13241337

1338+
1339+
13251340
if __name__ == '__main__':
13261341
run_tests()

torch/csrc/autograd/python_torch_functions_manual.cpp

+88-6
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,21 @@ static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject*
359359
static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
360360
{
361361
HANDLE_TH_ERRORS
362+
static PythonArgParser parser({
363+
"as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
364+
});
365+
366+
ParsedArgs<3> parsed_args;
367+
auto r = parser.parse(args, kwargs, parsed_args);
368+
if (r.has_torch_function()) {
369+
return handle_torch_function(
370+
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
371+
}
362372
jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR);
363-
return THPVariable_Wrap(torch::utils::as_tensor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
373+
return THPVariable_Wrap(torch::utils::as_tensor(
374+
torch::tensors::get_default_dispatch_key(),
375+
torch::tensors::get_default_scalar_type(),
376+
r));
364377
END_HANDLE_TH_ERRORS
365378
}
366379

@@ -397,32 +410,87 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject*
397410
static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
398411
{
399412
HANDLE_TH_ERRORS
413+
static PythonArgParser parser({
414+
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
415+
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
416+
});
417+
418+
ParsedArgs<9> parsed_args;
419+
auto r = parser.parse(args, kwargs, parsed_args);
420+
if (r.has_torch_function()) {
421+
return handle_torch_function(
422+
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
423+
}
400424
jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
401-
return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
425+
return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(
426+
torch::tensors::get_default_dispatch_key(),
427+
torch::tensors::get_default_scalar_type(),
428+
r));
402429
END_HANDLE_TH_ERRORS
403430
}
404431

405432
static PyObject * THPVariable__sparse_csr_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs)
406433
{
407434
HANDLE_TH_ERRORS
435+
static PythonArgParser parser({
436+
"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
437+
});
438+
439+
ParsedArgs<7> parsed_args;
440+
auto r = parser.parse(args, kwargs, parsed_args);
441+
if (r.has_torch_function()) {
442+
return handle_torch_function(
443+
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
444+
}
408445
jit::tracer::warn("torch._sparse_csr_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR);
409-
return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
446+
return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor(
447+
torch::tensors::get_default_dispatch_key(),
448+
torch::tensors::get_default_scalar_type(),
449+
r));
410450
END_HANDLE_TH_ERRORS
411451
}
412452

413453
static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
414454
{
415455
HANDLE_TH_ERRORS
456+
static PythonArgParser parser({
457+
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
458+
"sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
459+
"sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
460+
});
461+
462+
ParsedArgs<6> parsed_args;
463+
auto r = parser.parse(args, kwargs, parsed_args);
464+
if (r.has_torch_function()) {
465+
return handle_torch_function(
466+
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
467+
}
416468
jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR);
417-
return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
469+
return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(
470+
torch::tensors::get_default_dispatch_key(),
471+
torch::tensors::get_default_scalar_type(),
472+
r));
418473
END_HANDLE_TH_ERRORS
419474
}
420475

421476
static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs)
422477
{
423478
HANDLE_TH_ERRORS
479+
static PythonArgParser parser({
480+
"_sparse_coo_tensor_unsafe(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
481+
});
482+
483+
ParsedArgs<6> parsed_args;
484+
auto r = parser.parse(args, kwargs, parsed_args);
485+
if (r.has_torch_function()) {
486+
return handle_torch_function(
487+
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
488+
}
424489
jit::tracer::warn("torch._sparse_coo_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR);
425-
return THPVariable_Wrap(torch::utils::_sparse_coo_tensor_unsafe_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
490+
return THPVariable_Wrap(torch::utils::_sparse_coo_tensor_unsafe_ctor(
491+
torch::tensors::get_default_dispatch_key(),
492+
torch::tensors::get_default_scalar_type(),
493+
r));
426494
END_HANDLE_TH_ERRORS
427495
}
428496

@@ -431,8 +499,22 @@ static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject
431499
static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
432500
{
433501
HANDLE_TH_ERRORS
502+
static PythonArgParser parser({
503+
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
504+
});
505+
506+
constexpr int ctor_num_args = 6;
507+
ParsedArgs<ctor_num_args> parsed_args;
508+
auto r = parser.parse(args, kwargs, parsed_args);
509+
if (r.has_torch_function()) {
510+
return handle_torch_function(
511+
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
512+
}
434513
jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR);
435-
return THPVariable_Wrap(torch::utils::tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
514+
return THPVariable_Wrap(torch::utils::tensor_ctor(
515+
torch::tensors::get_default_dispatch_key(),
516+
torch::tensors::get_default_scalar_type(),
517+
r));
436518
END_HANDLE_TH_ERRORS
437519
}
438520

torch/csrc/utils/tensor_new.cpp

+24-45
Original file line numberDiff line numberDiff line change
@@ -592,16 +592,13 @@ Tensor indexing_tensor_from_data(
592592
}
593593
}
594594

595-
Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
595+
Tensor sparse_csr_tensor_ctor(
596+
c10::DispatchKey dispatch_key,
597+
at::ScalarType scalar_type,
598+
PythonArgs& r) {
596599
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
597600
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
598-
static PythonArgParser parser({
599-
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
600-
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
601-
});
602601
const int NUM_ARGS = 9, CROW_INDICES_ARG = 0, COL_INDICES_ARG = 1, VALUES_ARG = 2;
603-
ParsedArgs<NUM_ARGS> parsed_args;
604-
auto r = parser.parse(args, kwargs, parsed_args);
605602
auto safe_get_attr_string = [](PyObject *o, const char *attr_name) -> PyObject* {
606603
// Clear error indicator if attribute does not exists.
607604
// Otherwise subsequent Python C API calls might return bogus values.
@@ -667,7 +664,10 @@ Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
667664
throw std::runtime_error("sparse_csr_tensor(): invalid arguments");
668665
}
669666

670-
Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
667+
Tensor _sparse_csr_tensor_unsafe_ctor(
668+
c10::DispatchKey dispatch_key,
669+
at::ScalarType scalar_type,
670+
PythonArgs& r) {
671671
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
672672
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
673673
enum {
@@ -680,12 +680,6 @@ Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarT
680680
ARG_REQUIRES_GRAD,
681681
ARGS_COUNT
682682
};
683-
static PythonArgParser parser({
684-
"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
685-
});
686-
687-
ParsedArgs<ARGS_COUNT> parsed_args;
688-
auto r = parser.parse(args, kwargs, parsed_args);
689683
bool type_inference = r.isNone(ARG_TYPE);
690684
const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
691685
const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type);
@@ -726,17 +720,12 @@ Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarT
726720
// "this needs to be CUDA" and indices would be allocated on the wrong tensor.
727721
// Options is more right and gets this correct.
728722

729-
Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
723+
Tensor sparse_coo_tensor_ctor(
724+
c10::DispatchKey dispatch_key,
725+
at::ScalarType scalar_type,
726+
PythonArgs& r) {
730727
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
731728
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
732-
static PythonArgParser parser({
733-
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
734-
"sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
735-
"sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
736-
});
737-
738-
ParsedArgs<6> parsed_args;
739-
auto r = parser.parse(args, kwargs, parsed_args);
740729
if (r.idx == 0) {
741730
bool type_inference = r.isNone(2);
742731
const auto inferred_options = typeIdWithDefault(r, 3, dispatch_key);
@@ -773,7 +762,10 @@ Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
773762
throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
774763
}
775764

776-
Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
765+
Tensor _sparse_coo_tensor_unsafe_ctor(
766+
c10::DispatchKey dispatch_key,
767+
at::ScalarType scalar_type,
768+
PythonArgs& r) {
777769
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
778770
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
779771
enum {
@@ -785,12 +777,6 @@ Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarT
785777
ARG_REQUIRES_GRAD,
786778
ARGS_COUNT
787779
};
788-
static PythonArgParser parser({
789-
"_sparse_coo_tensor_unsafe(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
790-
});
791-
792-
ParsedArgs<ARGS_COUNT> parsed_args;
793-
auto r = parser.parse(args, kwargs, parsed_args);
794780
bool type_inference = r.isNone(ARG_TYPE);
795781
const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
796782
const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type);
@@ -846,14 +832,10 @@ void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarT
846832
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, r.intlist(3));
847833
}
848834

849-
Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
850-
static PythonArgParser parser({
851-
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",
852-
});
853-
854-
constexpr int ctor_num_args = 6;
855-
ParsedArgs<ctor_num_args> parsed_args;
856-
auto r = parser.parse(args, kwargs, parsed_args);
835+
Tensor tensor_ctor(
836+
c10::DispatchKey dispatch_key,
837+
at::ScalarType scalar_type,
838+
PythonArgs& r) {
857839
if (r.idx == 0) {
858840
PyObject* data = r.pyobject(0);
859841
if (THPVariable_Check(data)) {
@@ -886,14 +868,11 @@ Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, Py
886868
throw std::runtime_error("tensor(): invalid arguments");
887869
}
888870

889-
Tensor as_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
871+
Tensor as_tensor(
872+
c10::DispatchKey dispatch_key,
873+
at::ScalarType scalar_type,
874+
PythonArgs& r) {
890875
// TODO: add requires_grad once we decide on semantics for sharing data.
891-
static PythonArgParser parser({
892-
"as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
893-
});
894-
895-
ParsedArgs<3> parsed_args;
896-
auto r = parser.parse(args, kwargs, parsed_args);
897876
if (r.idx == 0) {
898877
bool type_inference = r.isNone(1);
899878
return internal_new_from_data(

torch/csrc/utils/tensor_new.h

+25-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <torch/csrc/python_headers.h>
4+
#include <torch/csrc/utils/python_arg_parser.h>
45

56
#include <ATen/core/Tensor.h>
67

@@ -14,14 +15,32 @@ at::Tensor indexing_tensor_from_data(
1415
at::ScalarType scalar_type,
1516
c10::optional<at::Device> device,
1617
PyObject* data);
17-
at::Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
18-
at::Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
18+
at::Tensor sparse_coo_tensor_ctor(
19+
c10::DispatchKey dispatch_key,
20+
at::ScalarType scalar_type,
21+
PythonArgs& r);
22+
at::Tensor _sparse_coo_tensor_unsafe_ctor(
23+
c10::DispatchKey dispatch_key,
24+
at::ScalarType scalar_type,
25+
PythonArgs& r);
1926
void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
20-
at::Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
21-
at::Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
27+
at::Tensor sparse_csr_tensor_ctor(
28+
c10::DispatchKey dispatch_key,
29+
at::ScalarType scalar_type,
30+
PythonArgs& r);
31+
at::Tensor _sparse_csr_tensor_unsafe_ctor(
32+
c10::DispatchKey dispatch_key,
33+
at::ScalarType scalar_type,
34+
PythonArgs& r);
2235
void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
23-
at::Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
24-
at::Tensor as_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
36+
at::Tensor tensor_ctor(
37+
c10::DispatchKey dispatch_key,
38+
at::ScalarType scalar_type,
39+
PythonArgs& r);
40+
at::Tensor as_tensor(
41+
c10::DispatchKey dispatch_key,
42+
at::ScalarType scalar_type,
43+
PythonArgs& r);
2544
at::Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
2645
at::Tensor new_ones(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
2746
at::Tensor tensor_frombuffer(PyObject* buffer, at::ScalarType dtype, int64_t count, int64_t offset, bool requires_grad);

0 commit comments

Comments
 (0)