forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbackend_init.cpp
191 lines (170 loc) · 7.23 KB
/
backend_init.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#include <torch/csrc/jit/backends/backend_init.h>
#include <pybind11/iostream.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::jit {
// Collect the module types that occur more than once among the modules
// yielded by \p mod.modules() (the hierarchy rooted at \p mod, root included).
// The returned set contains each such repeated type exactly once.
std::unordered_set<TypePtr> getSharedModuleTypes(Module& mod) {
  // Every type encountered so far during the walk.
  std::unordered_set<TypePtr> seen;
  // Types for which a second (or later) occurrence was found.
  std::unordered_set<TypePtr> shared;

  for (auto submodule : mod.modules()) {
    auto ty = submodule.type();
    // insert() reports through .second whether the element was newly added;
    // a rejected insertion means this type was already seen, so it is shared.
    const bool first_occurrence = seen.insert(ty).second;
    if (!first_occurrence) {
      shared.insert(ty);
    }
  }

  return shared;
}
// Selectively lower submodules of \p mod to a backend.
//
// \p to_backend is the Python callable used to lower a single module. It is
// handed the module wrapped by torch.jit._recursive.wrap_cpp_module, and the
// `_c` attribute of its result is cast back to a Module.
// \p modules_to_lower contains the qualified names of the submodules of
// \p mod that should be lowered.
// \p duplicate_types is the set of module types treated as shared. Lowering
// is refused when the parent of a selected module has one of these types,
// because the parent's type is edited below.
//
// Throws std::runtime_error when a component of a qualified name is an
// attribute that is not a Module, and py::cast_error when the parent's type
// is in \p duplicate_types.
void toBackendSelectiveImpl(
    Module& mod,
    const py::function& to_backend,
    const std::vector<std::string>& modules_to_lower,
    const std::unordered_set<TypePtr>& duplicate_types) {
  // Original type -> lowered type, filled in as submodules are lowered and
  // applied to every graph and schema in the hierarchy afterwards.
  std::unordered_map<TypePtr, TypePtr> type_remap;

  for (const auto& qualified_name : modules_to_lower) {
    // QualifiedName splits the dotted name into its atoms.
    c10::QualifiedName parsed_name(qualified_name);
    const auto& path = parsed_name.atoms();
    const size_t depth = path.size();

    // Walk down from the root one atom at a time. When the walk finishes,
    // `target` is the module to lower and `owner` is its direct parent.
    Module target = mod;
    Module owner;
    for (size_t level = 0; level < depth; ++level) {
      const auto& atom = path[level];
      IValue child = target.attr(atom);
      if (!child.isModule()) {
        throw std::runtime_error(
            c10::str("Attribute named ", atom, " is not a Module"));
      }
      // Remember the parent just before descending into the final atom.
      if (level + 1 == depth) {
        owner = target;
      }
      target = child.toModule();
    }

    // The parent's type is about to be edited, so it must not be shared.
    if (duplicate_types.count(owner.type()) != 0) {
      throw py::cast_error(c10::str(
          "Selective lowering is only supported for module hierarchies with unique types for selected modules; ",
          owner.type()->repr_str(),
          " is shared"));
    }

    // to_backend works on wrapped modules: wrap the target, lower it, then
    // unwrap the result (via `_c`) so that its type can be accessed below.
    py::object wrap_cpp_module =
        py::module::import("torch.jit._recursive").attr("wrap_cpp_module");
    py::object wrapped_target = wrap_cpp_module(target);
    py::object wrapped_lowered = to_backend(wrapped_target);
    Module lowered = py::cast<Module>(wrapped_lowered.attr("_c"));

    // Retype the parent's attribute slot to the lowered type first, then
    // store the lowered module in that slot.
    const auto& leaf_name = path.back();
    auto owner_type = owner.type();
    owner_type->unsafeChangeAttributeType(leaf_name, lowered.type());
    owner.setattr(leaf_name, lowered._ivalue());

    // Record old type -> lowered type for the remapping pass.
    type_remap[target.type()] = lowered.type();
  }

  // Map a type to its lowered replacement when one was recorded; otherwise
  // return the type unchanged.
  auto type_remap_fn = [&type_remap](TypePtr in) -> TypePtr {
    auto found = type_remap.find(in);
    return found == type_remap.end() ? in : found->second;
  };

  // With all requested modules lowered, rewrite the graph and the schema of
  // every method in the hierarchy (modules() includes the root) so that they
  // refer to the lowered types.
  for (auto submodule : mod.modules()) {
    auto submodule_type = submodule.type();
    for (auto* fn : submodule_type->methods()) {
      submodule.get_method(fn->name()).graph()->remapTypes(type_remap_fn);
      fn->setSchema(fn->getSchema().cloneWithRemappedTypes(type_remap_fn));
    }
  }
}
// Convert the Python dict \p method_compile_spec to a generic dict IValue and
// forward it, together with \p backend_name and \p orig_module, to
// detail::codegen_backend_module, returning the Module that call produces.
Module codegen_func(
    const std::string& backend_name,
    const Module& orig_module,
    const py::dict& method_compile_spec) {
  // The type Dict[str, Any], used both to convert the spec and as the
  // dictionary type handed to the code generator.
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());
  auto compile_spec =
      toIValue(method_compile_spec, spec_type).toGenericDict();
  return detail::codegen_backend_module(
      backend_name, orig_module, compile_spec, spec_type);
}
// Register the backend-lowering entry points on the Python module \p module.
//
// `_jit_to_backend(backend_name, module, spec)` lowers a whole module. The
// name of the backend must be the first argument. For example, to lower a
// Module to "example_backend", declared as
//
//  static auto cls = torch::jit::backend<ExampleBackend>("example_backend");
//
// this function must be called like
//
//  torch._C._jit_to_backend("example_backend", module, spec)
//
// `_jit_to_backend_selective(module, to_backend, modules_to_lower)` lowers
// only the named submodules of a clone of `module`.
void initJitBackendBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module_>();

  m.def(
      "_jit_to_backend",
      [](const std::string& backend_name,
         py::handle orig_module,
         const py::dict& method_compile_spec) {
        // Send std::cerr / std::cout to Python's sys.stderr / sys.stdout
        // for the duration of this call.
        py::scoped_ostream_redirect stderr_guard(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect stdout_guard(
            std::cout, py::module_::import("sys").attr("stdout"));

        auto recursive = py::module_::import("torch.jit._recursive");
        // Lower the C++ module held in `_c`, then wrap the result again.
        Module lowered = codegen_func(
            backend_name,
            py::cast<Module>(orig_module.attr("_c")),
            method_compile_spec);
        return recursive.attr("wrap_cpp_module")(lowered);
      });

  m.def(
      "_jit_to_backend_selective",
      [](py::handle orig_module,
         const py::function& to_backend,
         const std::vector<std::string>& modules_to_lower) {
        // Send std::cerr / std::cout to Python's sys.stderr / sys.stdout
        // for the duration of this call.
        py::scoped_ostream_redirect stderr_guard(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect stdout_guard(
            std::cout, py::module_::import("sys").attr("stdout"));

        auto script_module = as_module(py::cast<py::object>(orig_module));
        if (!script_module) {
          throw py::cast_error(c10::str(
              "Object ", py::str(orig_module), " is not a ScriptModule"));
        }

        // Work on a clone to avoid editing types that are shared with
        // Modules in other instances outside this hierarchy.
        Module working_copy = script_module->clone();
        // Get all shared module types. Type sharing is only a problem if the
        // parent modules of the ones to lower are in this set.
        auto shared_types = getSharedModuleTypes(working_copy);
        toBackendSelectiveImpl(
            working_copy, to_backend, modules_to_lower, shared_types);

        // Wrap the result in a RecursiveScriptModule because that's what
        // the caller passed in.
        return py::module_::import("torch.jit._recursive")
            .attr("wrap_cpp_module")(working_copy);
      });
}
} // namespace torch::jit