#include <ATen/core/ivalue.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/pickle.h>

#include <algorithm>
#include <cstring>

namespace torch {
namespace jit {

// These are both defined in `torch/serialization.py`
const char* torch_save_magic_number =
    "\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19";
uint16_t protocol_version = 1001;
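
// The magic-number string above holds the little-endian bytes of
// 0x1950a86a20f9469cfc6c. `torch.save` emits consecutive pickle programs for
// the magic number, protocol version, sys_info, and the data, then a pickled
// tuple of storage keys, then each raw tensor storage prefixed with its
// element count; pickle_save below reproduces that layout.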

// Pickle an IValue, handing each chunk of output to `writer` as it is
// produced.
void pickle(
    std::function<void(const char* data_start, size_t data_len)> writer,
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  Pickler pickler(std::move(writer), tensor_table);
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();
}
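
// Example (sketch, not from the original source): stream the pickle bytes
// straight to disk; `my_ivalue` and the filename are placeholders, and
// <fstream> is assumed to be included.
//
//   std::ofstream out("ivalue.pkl", std::ios::binary);
//   jit::pickle(
//       [&](const char* buf, size_t len) { out.write(buf, len); },
//       my_ivalue,
//       /*tensor_table=*/nullptr);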

// Convenience overload that collects the pickled bytes into a single buffer.
std::vector<char> pickle(
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  std::vector<char> data;

  pickle(
      [&](const char* bytes, size_t len) {
        data.insert(data.end(), bytes, bytes + len);
      },
      ivalue,
      tensor_table);

  return data;
}
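
// Example (sketch): pickle to memory; `my_ivalue` is a placeholder for any
// IValue.
//
//   std::vector<char> bytes = jit::pickle(my_ivalue, /*tensor_table=*/nullptr);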

// This has to live here instead of the C++ API to mirror `torch.save`, since
// the mobile build excludes the C++ API.
std::vector<char> pickle_save(const at::IValue& ivalue) {
  std::vector<char> data;

  auto writer = [&](const char* bytes, size_t len) {
    data.insert(data.end(), bytes, bytes + len);
  };

  jit::Pickler pickler(writer, /*tensor_table=*/nullptr);

  // Output data to match torch.save, see torch/serialization.py for details
  // Magic number (0x1950a86a20f9469cfc6c)
  pickler.protocol();
  pickler.pushLong(torch_save_magic_number);
  pickler.stop();

  // Protocol version
  pickler.protocol();
  pickler.pushInt(protocol_version);
  pickler.stop();

  // sys_info; this isn't actually used during de-serialization, so it can be
  // left empty
  pickler.protocol();
  pickler.pushEmptyDict();
  pickler.stop();

  // The data itself
  jit::Pickler data_pickler(writer, /*tensor_table=*/nullptr);
  data_pickler.protocol();
  data_pickler.pushIValue(ivalue);
  data_pickler.stop();

  // A tuple of string keys ("0", "1", ...), one per tensor storage written
  // below
  auto writeable_tensors = data_pickler.tensorData();
  std::vector<at::IValue> keys;
  keys.reserve(writeable_tensors.size());
  for (size_t i = 0; i < writeable_tensors.size(); i++) {
    keys.emplace_back(c10::to_string(i));
  }
  auto keys_tuple = at::ivalue::Tuple::create(keys);
  jit::pickle(writer, keys_tuple);

  // Finally, each raw tensor storage, prefixed with its element count
  for (const auto& tensor_data : writeable_tensors) {
    const char* addr = tensor_data.data();
    size_t numel = tensor_data.numel();
    writer(reinterpret_cast<const char*>(&numel), sizeof(numel));
    writer(addr, tensor_data.sizeInBytes());
  }

  return data;
}
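
// Example (sketch, on the assumption that this layout remains loadable by
// `torch.load`): save a tensor from C++; the tensor value and filename are
// placeholders, with <fstream> and the torch C++ API assumed available.
//
//   auto bytes = jit::pickle_save(at::IValue(torch::ones({2, 2})));
//   std::ofstream out("data.pt", std::ios::binary);
//   out.write(bytes.data(), bytes.size());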

// Unpickle an IValue, pulling bytes on demand through `reader`. `reader`
// should fill at most `len` bytes and return how many it actually wrote,
// with 0 signalling that the data is exhausted.
IValue unpickle(
    std::function<size_t(char*, size_t)> reader,
    ClassResolver class_resolver,
    const std::vector<at::Tensor>* tensor_table) {
  Unpickler unpickler(
      std::move(reader), std::move(class_resolver), tensor_table);
  return unpickler.parse_ivalue();
}
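
// Example (sketch): feed the unpickler from a stream; the filename is a
// placeholder, and nullptr is passed explicitly rather than relying on any
// defaults declared in pickle.h.
//
//   std::ifstream in("ivalue.pkl", std::ios::binary);
//   IValue val = jit::unpickle(
//       [&](char* buf, size_t len) -> size_t {
//         in.read(buf, len);
//         return static_cast<size_t>(in.gcount());
//       },
//       /*class_resolver=*/nullptr,
//       /*tensor_table=*/nullptr);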

IValue unpickle(
    const char* data,
    size_t size,
    ClassResolver class_resolver,
    const std::vector<at::Tensor>* tensor_table) {
  size_t bytes_read = 0;
  return unpickle(
      [&](char* buffer, size_t len) -> size_t {
        if (bytes_read >= size) {
          return 0;
        }
        len = std::min(size - bytes_read, len);
        // Copy len bytes into buffer
        const char* start = data + bytes_read;
        std::memcpy(buffer, start, len);
        bytes_read += len;
        return len;
      },
      std::move(class_resolver),
      tensor_table);
}
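
// Example (sketch): round-trip a value through the in-memory overloads;
// nullptr is passed explicitly rather than assuming defaults.
//
//   std::vector<char> bytes = jit::pickle(IValue(42), nullptr);
//   IValue out = jit::unpickle(bytes.data(), bytes.size(), nullptr, nullptr);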
} // namespace jit
} // namespace torch