Skip to content

Commit 51f6b94

Browse files
angelayipytorchmergebot
authored andcommitted
[torchbind] Add generic __deepcopy__ method (pytorch#137613)
Summary: Added a generic `__deepcopy__` method which will use the torchbind object's existing `__getattr__` and `__setattr__` to copy the torchbind object. This will later be used in [D64124825](https://www.internalfb.com/diff/D64124825) Differential Revision: D64124826 Pull Request resolved: pytorch#137613 Approved by: https://github.com/ydwu4, https://github.com/zou3519
1 parent 282e638 commit 51f6b94

File tree

3 files changed

+110
-13
lines changed

3 files changed

+110
-13
lines changed

test/cpp/jit/test_custom_class_registrations.cpp

+38-12
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,39 @@ struct TensorQueue : torch::CustomClassHolder {
145145
}
146146
}
147147

148-
c10::Dict<std::string, at::Tensor> serialize() const {
149-
c10::Dict<std::string, at::Tensor> dict;
150-
dict.insert(std::string("init_tensor"), init_tensor_);
151-
const std::string key = "queue";
152-
dict.insert(
153-
key + "/size", torch::tensor(static_cast<int64_t>(queue_.size())));
154-
for (const auto index : c10::irange(queue_.size())) {
155-
dict.insert(key + "/" + std::to_string(index), queue_[index]);
148+
std::tuple<
149+
std::tuple<std::string, at::Tensor>,
150+
std::tuple<std::string, std::vector<at::Tensor>>>
151+
serialize() {
152+
return std::tuple(
153+
std::tuple("init_tensor", this->init_tensor_.clone()),
154+
std::tuple("queue", this->clone_queue()));
155+
}
156+
157+
static c10::intrusive_ptr<TensorQueue> deserialize(
158+
std::tuple<
159+
std::tuple<std::string, at::Tensor>,
160+
std::tuple<std::string, std::vector<at::Tensor>>> flattened) {
161+
TORCH_CHECK(std::tuple_size<decltype(flattened)>::value == 2);
162+
163+
auto init_tensor_tuple = std::get<0>(flattened);
164+
TORCH_CHECK(std::tuple_size<decltype(init_tensor_tuple)>::value == 2);
165+
TORCH_CHECK(std::get<0>(init_tensor_tuple) == std::string("init_tensor"));
166+
167+
c10::intrusive_ptr<TensorQueue> queue =
168+
c10::make_intrusive<TensorQueue>(std::get<1>(init_tensor_tuple));
169+
170+
auto queue_tuple = std::get<1>(flattened);
171+
TORCH_CHECK(std::tuple_size<decltype(queue_tuple)>::value == 2);
172+
TORCH_CHECK(std::get<0>(queue_tuple) == std::string("queue"));
173+
174+
for (auto& value : std::get<1>(queue_tuple)) {
175+
queue->push(value);
156176
}
157-
return dict;
177+
178+
return queue;
158179
}
180+
159181
// Push the element to the rear of queue.
160182
// Lock is added for thread safe.
161183
void push(at::Tensor x) {
@@ -639,13 +661,17 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
639661
.def_pickle(
640662
// __getstate__
641663
[](const c10::intrusive_ptr<TensorQueue>& self)
642-
-> c10::Dict<std::string, at::Tensor> {
664+
-> std::tuple<
665+
std::tuple<std::string, at::Tensor>,
666+
std::tuple<std::string, std::vector<at::Tensor>>> {
643667
return self->serialize();
644668
},
645669
// __setstate__
646-
[](c10::Dict<std::string, at::Tensor> data)
670+
[](std::tuple<
671+
std::tuple<std::string, at::Tensor>,
672+
std::tuple<std::string, std::vector<at::Tensor>>> data)
647673
-> c10::intrusive_ptr<TensorQueue> {
648-
return c10::make_intrusive<TensorQueue>(std::move(data));
674+
return TensorQueue::deserialize(data);
649675
});
650676
}
651677

test/export/test_torchbind.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Owner(s): ["oncall: export"]
22

3-
3+
import copy
44
import unittest
55

66
import torch
@@ -1028,6 +1028,30 @@ def forward(self, token, tq, x):
10281028
return (tq,)""", # noqa: B950
10291029
)
10301030

1031+
def test_deepcopy(self):
1032+
tq = torch.classes._TorchScriptTesting._TensorQueue(
1033+
torch.empty(
1034+
0,
1035+
).fill_(-1)
1036+
)
1037+
tq_0 = copy.deepcopy(tq)
1038+
tq.push(torch.zeros(2, 2))
1039+
tq.push(torch.ones(2, 2))
1040+
tq_1 = copy.deepcopy(tq)
1041+
tq.push(torch.ones(2, 2) * 2)
1042+
self.assertEqual(tq_0.size(), 0)
1043+
self.assertEqual(tq_1.size(), 2)
1044+
self.assertEqual(tq.size(), 3)
1045+
1046+
foo = torch.classes._TorchScriptTesting._Foo(1, 2)
1047+
foo_0 = copy.deepcopy(foo)
1048+
foo.increment(1)
1049+
foo_1 = copy.deepcopy(foo)
1050+
foo.increment(1)
1051+
self.assertEqual(foo_0.add(1), 3)
1052+
self.assertEqual(foo_1.add(1), 5)
1053+
self.assertEqual(foo.add(1), 7)
1054+
10311055

10321056
class TestCompileTorchbind(TestCase):
10331057
def setUp(self):

torch/csrc/jit/python/script_init.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,53 @@ void initJitScriptBindings(PyObject* module) {
866866
// Similar to Tensor's `__hash__`, which is `id()`.
867867
return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
868868
})
869+
.def(
870+
"__deepcopy__",
871+
[](const Object& self, const py::dict& memo) {
872+
if (auto getstate_method = self.find_method("__getstate__")) {
873+
auto object_state = toPyObject((*getstate_method)(Stack{}));
874+
875+
if (auto qualname = self.type()->name()) {
876+
auto class_type = getCustomClass(qualname->qualifiedName());
877+
auto self = Object(c10::ivalue::Object::create(
878+
c10::StrongTypePtr(
879+
std::shared_ptr<torch::jit::CompilationUnit>(),
880+
class_type),
881+
1));
882+
883+
if (auto setstate_method =
884+
self.find_method("__setstate__")) {
885+
auto setstate_schema =
886+
setstate_method->function().getSchema();
887+
TORCH_INTERNAL_ASSERT(
888+
setstate_schema.arguments().size() == 2,
889+
"__setstate__ method for class ",
890+
class_type->repr_str(),
891+
" must have exactly 2 arguments!");
892+
auto state_type =
893+
setstate_schema.arguments().at(1).type();
894+
(*setstate_method)(
895+
Stack{toIValue(object_state, state_type)});
896+
return self;
897+
}
898+
std::stringstream err;
899+
err << "Tried to deepcopy object ";
900+
if (auto qualname = class_type->name()) {
901+
err << qualname->qualifiedName() << " ";
902+
}
903+
err << "which does not have a __setstate__ method defined!";
904+
throw std::runtime_error(err.str());
905+
}
906+
}
907+
908+
std::stringstream err;
909+
err << "Tried to deepcopy object ";
910+
if (auto qualname = self.type()->name()) {
911+
err << qualname->qualifiedName() << " ";
912+
}
913+
err << "which does not have a __getstate__ method defined!";
914+
throw std::runtime_error(err.str());
915+
})
869916
.def(py::pickle(
870917
[](const Object& self)
871918
-> std::tuple<py::object, std::string> { // __getstate__

0 commit comments

Comments
 (0)