forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpybind_workspace.cc
82 lines (68 loc) · 2.17 KB
/
pybind_workspace.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include "caffe2/core/workspace.h"
#include "caffe2/python/pybind_workspace.h"
namespace caffe2 {
namespace python {
// Out-of-line definition of the BlobFetcherBase destructor; it performs no
// work of its own.
BlobFetcherBase::~BlobFetcherBase() = default;
// Defines the BlobFetcherRegistry registry through c10's typed-registry macro.
// The arguments are, in order: the registry name, TypeIdentifier,
// BlobFetcherBase and std::unique_ptr -- presumably the key type, the
// registered base class and the owning pointer wrapper (see c10's Registry.h
// to confirm the exact roles).
C10_DEFINE_TYPED_REGISTRY(
BlobFetcherRegistry,
TypeIdentifier,
BlobFetcherBase,
std::unique_ptr);
// gWorkspace is a non-owning pointer to the current workspace. Ownership is
// meant to be kept by the gWorkspaces map; it starts out as nullptr.
static Workspace* gWorkspace = nullptr;
// Name associated with the current workspace. SwitchWorkspaceInternal updates
// it, and ResetWorkspace uses it as the key into gWorkspaces.
static std::string gCurrentWorkspaceName;
// gWorkspaces allows us to define and switch between multiple workspaces in
// Python. It maps each workspace name to the Workspace object it owns.
static std::map<std::string, std::unique_ptr<Workspace>> gWorkspaces;
// Returns the non-owning pointer stored in gWorkspace, i.e. the current
// workspace. The result is nullptr if no workspace has been made current yet.
Workspace* GetCurrentWorkspace() {
  Workspace* const current = gWorkspace;
  return current;
}
// Makes `workspace` the current workspace by storing the raw pointer in
// gWorkspace. Only the pointer is recorded: gWorkspaces and
// gCurrentWorkspaceName are left untouched, so no ownership is taken here.
void SetCurrentWorkspace(Workspace* workspace) {
  gWorkspace = workspace;
}
// Creates a default-constructed Workspace, makes it the current workspace and
// returns a non-owning pointer to it.
//
// The new workspace is stored in gWorkspaces under gCurrentWorkspaceName, so
// the map owns it and the returned pointer stays valid after this function
// returns. Any workspace previously registered under that name is destroyed,
// which is the same replacement behavior as ResetWorkspace.
Workspace* NewWorkspace() {
  // operator[] creates the entry if the current name is not registered yet.
  std::unique_ptr<Workspace>& slot = gWorkspaces[gCurrentWorkspaceName];
  slot.reset(new Workspace());
  gWorkspace = slot.get();
  return gWorkspace;
}
// Looks up the workspace registered under `name` in gWorkspaces.
// Returns a non-owning pointer to it, or nullptr when there is no entry for
// `name`. The current workspace is not changed.
Workspace* GetWorkspaceByName(const std::string& name) {
  const auto match = gWorkspaces.find(name);
  return match == gWorkspaces.end() ? nullptr : match->second.get();
}
// Returns a copy of gCurrentWorkspaceName, the name recorded for the current
// workspace.
std::string GetCurrentWorkspaceName() {
  std::string current_name = gCurrentWorkspaceName;
  return current_name;
}
// Registers `ws` under `name`, handing ownership over to gWorkspaces.
// If an entry for `name` already exists, that entry is left unchanged and the
// workspace passed in is destroyed by the time this function returns.
// The current workspace (gWorkspace / gCurrentWorkspaceName) is not changed.
void InsertWorkspace(const std::string& name, std::unique_ptr<Workspace> ws) {
  gWorkspaces.emplace(name, std::move(ws));
}
// Makes the workspace registered under `name` the current workspace.
//
// If no workspace is registered under `name`, a default-constructed one is
// created and registered first, provided `create_if_missing` is true;
// otherwise the CAFFE_ENFORCE check below fails.
//
// gWorkspace is only updated once the workspace is owned by gWorkspaces, so a
// failure while inserting (for example an allocation failure when copying the
// key) cannot leave gWorkspace pointing at a destroyed workspace.
void SwitchWorkspaceInternal(const std::string& name, bool create_if_missing) {
  auto entry = gWorkspaces.find(name);
  if (entry == gWorkspaces.end()) {
    CAFFE_ENFORCE(create_if_missing);
    entry =
        gWorkspaces.emplace(name, std::unique_ptr<Workspace>(new Workspace()))
            .first;
  }
  // Assign the name first: it is the only remaining step that can throw, so
  // the pointer and the name never get out of sync.
  gCurrentWorkspaceName = name;
  gWorkspace = entry->second.get();
}
// Replaces the workspace registered under gCurrentWorkspaceName with
// `workspace` and makes it the current workspace. The map entry takes
// ownership of `workspace`; the workspace it held before, if any, is
// destroyed. An entry is created if the current name was not registered.
void ResetWorkspace(Workspace* workspace) {
  std::unique_ptr<Workspace>& current_slot = gWorkspaces[gCurrentWorkspaceName];
  current_slot.reset(workspace);
  gWorkspace = current_slot.get();
}
// Appends the name of every workspace registered in gWorkspaces to `names`,
// in the map's iteration (key) order. Elements already in `names` are kept.
void GetWorkspaceNames(std::vector<std::string>& names) {
  for (auto it = gWorkspaces.begin(); it != gWorkspaces.end(); ++it) {
    // NOLINTNEXTLINE(performance-inefficient-vector-operation)
    names.push_back(it->first);
  }
}
void ClearWorkspaces() {
gWorkspaces.clear();
}
} // namespace python
} // namespace caffe2