Skip to content

Commit d0db23e

Browse files
bwastifacebook-github-bot
authored andcommitted
Add distributed annotations
Summary: Annotations for DAI Reviewed By: duc0 Differential Revision: D9805867 fbshipit-source-id: 9ce2d9f3984817510ec8362a281f39878aad55e7
1 parent de11fe0 commit d0db23e

File tree

5 files changed

+194
-57
lines changed

5 files changed

+194
-57
lines changed

caffe2/opt/annotations.cc

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include "caffe2/opt/annotations.h"
2+
3+
namespace caffe2 {
4+
5+
using namespace nom::repr;
6+
7+
void Caffe2Annotation::setOperatorDef(const caffe2::OperatorDef& opDef) {
8+
OpDef = opDef;
9+
OpDefExists = true;
10+
}
11+
12+
bool Caffe2Annotation::hasOperatorDef() const {
13+
return OpDefExists;
14+
}
15+
16+
const caffe2::OperatorDef& Caffe2Annotation::getOperatorDef() const {
17+
CAFFE_ENFORCE(
18+
OpDefExists,
19+
"OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
20+
return OpDef;
21+
}
22+
caffe2::OperatorDef* Caffe2Annotation::getMutableOperatorDef() {
23+
CAFFE_ENFORCE(
24+
OpDefExists,
25+
"OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
26+
return &OpDef;
27+
}
28+
29+
// Distributed annotations
30+
void Caffe2Annotation::setDevice(std::string device) {
31+
Device = device;
32+
}
33+
const std::string Caffe2Annotation::getDevice() const {
34+
return Device;
35+
}
36+
37+
void Caffe2Annotation::setDeviceType(int device) {
38+
DeviceType = device;
39+
}
40+
int Caffe2Annotation::getDeviceType() const {
41+
return DeviceType;
42+
}
43+
44+
void Caffe2Annotation::setParallelization(
45+
Caffe2Annotation::ParallelizationScheme s,
46+
int num) {
47+
parallelization_scheme_ = s;
48+
parallelization_ = num;
49+
}
50+
51+
Caffe2Annotation::ParallelizationScheme
52+
Caffe2Annotation::getParallelizationScheme() const {
53+
return parallelization_scheme_;
54+
}
55+
56+
int Caffe2Annotation::getParallelization() const {
57+
return parallelization_;
58+
}
59+
60+
void Caffe2Annotation::setKeyNode(NNGraph::NodeRef n) {
61+
key_node_ = n;
62+
}
63+
const NNGraph::NodeRef& Caffe2Annotation::getKeyNode() const {
64+
CAFFE_ENFORCE(key_node_, "No key node has been annotated");
65+
return key_node_;
66+
}
67+
void Caffe2Annotation::setLengthNode(NNGraph::NodeRef n) {
68+
length_node_ = n;
69+
}
70+
const NNGraph::NodeRef& Caffe2Annotation::getLengthNode() const {
71+
CAFFE_ENFORCE(length_node_, "No length node has been annotated");
72+
return length_node_;
73+
}
74+
75+
void Caffe2Annotation::setComponentLevels(std::vector<std::string> components) {
76+
component_levels_ = components;
77+
}
78+
std::vector<std::string> Caffe2Annotation::getComponentLevels() const {
79+
return component_levels_;
80+
}
81+
82+
bool Caffe2Annotation::classof(const Annotation* A) {
83+
return A->getKind() == AnnotationKind::Caffe2;
84+
}
85+
86+
} // namespace caffe2

caffe2/opt/annotations.h

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#pragma once
2+
3+
#include "caffe2/core/common.h"
4+
#include "caffe2/core/logging.h"
5+
#include "caffe2/proto/caffe2_pb.h"
6+
#include "nomnigraph/Representations/NeuralNet.h"
7+
8+
namespace caffe2 {
9+
10+
class Caffe2Annotation : public nom::repr::Annotation {
11+
public:
12+
Caffe2Annotation() : Annotation(AnnotationKind::Caffe2) {}
13+
Caffe2Annotation(std::string device)
14+
: Annotation(AnnotationKind::Caffe2), Device(device) {}
15+
virtual ~Caffe2Annotation() {}
16+
17+
void setOperatorDef(const caffe2::OperatorDef& opDef);
18+
bool hasOperatorDef() const;
19+
const caffe2::OperatorDef& getOperatorDef() const;
20+
caffe2::OperatorDef* getMutableOperatorDef();
21+
22+
// Distributed annotations
23+
void setDevice(std::string device);
24+
const std::string getDevice() const;
25+
void setDeviceType(int device);
26+
int getDeviceType() const;
27+
28+
enum class ParallelizationScheme {
29+
none,
30+
split_by_batch,
31+
split_by_length,
32+
shard,
33+
shard_by_number
34+
};
35+
void setParallelization(ParallelizationScheme, int num = -1);
36+
ParallelizationScheme getParallelizationScheme() const;
37+
int getParallelization() const;
38+
39+
void setKeyNode(nom::repr::NNGraph::NodeRef);
40+
const nom::repr::NNGraph::NodeRef& getKeyNode() const;
41+
void setLengthNode(nom::repr::NNGraph::NodeRef);
42+
const nom::repr::NNGraph::NodeRef& getLengthNode() const;
43+
44+
void setComponentLevels(std::vector<std::string> components);
45+
std::vector<std::string> getComponentLevels() const;
46+
47+
static bool classof(const Annotation* A);
48+
49+
private:
50+
std::string Device = "";
51+
caffe2::OperatorDef OpDef;
52+
bool OpDefExists = false;
53+
54+
// Distributed annotations
55+
int DeviceType = caffe2::DeviceTypeProto::PROTO_CPU;
56+
ParallelizationScheme parallelization_scheme_ = ParallelizationScheme::none;
57+
int parallelization_ = -1;
58+
nom::repr::NNGraph::NodeRef key_node_ = nullptr;
59+
nom::repr::NNGraph::NodeRef length_node_ = nullptr;
60+
std::vector<std::string> component_levels_;
61+
};
62+
63+
} // namespace caffe2

caffe2/opt/converter.h

+1-55
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "caffe2/core/common.h"
55
#include "caffe2/core/logging.h"
6+
#include "caffe2/opt/annotations.h"
67
#include "caffe2/proto/caffe2_pb.h"
78
#include "nomnigraph/Graph/Graph.h"
89
#include "nomnigraph/Representations/ControlFlow.h"
@@ -12,62 +13,7 @@
1213

1314
namespace caffe2 {
1415

15-
class Caffe2Annotation : public nom::repr::Annotation {
16-
public:
17-
Caffe2Annotation() : Annotation(AnnotationKind::Caffe2) {}
18-
Caffe2Annotation(std::string device)
19-
: Annotation(AnnotationKind::Caffe2), Device(device) {}
20-
virtual ~Caffe2Annotation() {}
21-
22-
void setDevice(std::string device) {
23-
Device = device;
24-
}
25-
const std::string getDevice() const {
26-
return Device;
27-
}
28-
29-
void setDeviceType(int device) {
30-
DeviceType = device;
31-
}
32-
int getDeviceType() const {
33-
return DeviceType;
34-
}
35-
36-
void setOperatorDef(const caffe2::OperatorDef& opDef) {
37-
OpDef = opDef;
38-
OpDefExists = true;
39-
}
40-
41-
bool hasOperatorDef() const {
42-
return OpDefExists;
43-
}
44-
45-
const caffe2::OperatorDef& getOperatorDef() const {
46-
CAFFE_ENFORCE(
47-
OpDefExists,
48-
"OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
49-
return OpDef;
50-
}
51-
caffe2::OperatorDef* getMutableOperatorDef() {
52-
CAFFE_ENFORCE(
53-
OpDefExists,
54-
"OperatorDef was never set. Use Caffe2Annotation::setOperatorDef.");
55-
return &OpDef;
56-
}
57-
58-
static bool classof(const Annotation *A) {
59-
return A->getKind() == AnnotationKind::Caffe2;
60-
}
61-
62-
private:
63-
std::string Device = "";
64-
caffe2::OperatorDef OpDef;
65-
bool OpDefExists = false;
66-
int DeviceType = caffe2::DeviceTypeProto::PROTO_CPU;
67-
};
68-
6916
CAFFE2_API nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict = false);
70-
7117
CAFFE2_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&);
7218

7319
// Pass in an oldNet to copy all the attributes of that network.

caffe2/python/nomnigraph_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,23 @@ def test_annotation_from_graph(self):
223223
node.setAnnotation(annot)
224224
new_annot = node.getAnnotation()
225225
assert new_annot.getDeviceType() == 7
226+
227+
def test_distribute_annotations(self):
228+
nn = ng.NNModule()
229+
key = nn.dataFlow.createNode(ng.NeuralNetData("key"))
230+
length = nn.dataFlow.createNode(ng.NeuralNetData("length"))
231+
node = nn.dataFlow.createNode(ng.NeuralNetOperator("TestOp"))
232+
233+
annot = ng.Annotation()
234+
annot.setKeyNode(key)
235+
annot.setLengthNode(length)
236+
annot.setComponentLevels(["", "test", "woot"])
237+
238+
node.setAnnotation(annot)
239+
240+
new_annot = node.getAnnotation()
241+
#assert new_annot.getLengthNode() == length
242+
assert new_annot.getKeyNode() == key
243+
assert len(new_annot.getComponentLevels()) == 3
244+
assert new_annot.getComponentLevels()[0] == ""
245+
assert new_annot.getComponentLevels()[2] == "woot"

caffe2/python/pybind_state_nomni.cc

+24-2
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,19 @@ void addNomnigraphMethods(pybind11::module& m) {
199199
return nn::get<nom::repr::Tensor>(n);
200200
},
201201
py::return_value_policy::reference_internal)
202+
.def_property(
203+
"annotation",
204+
[](NNGraph::NodeRef n) { return getOrAddCaffe2Annotation(n); },
205+
[](NNGraph::NodeRef n, Caffe2Annotation annot) {
206+
auto* nnOp = nn::get<NeuralNetOperator>(n);
207+
nnOp->setAnnotation(
208+
nom::util::make_unique<Caffe2Annotation>(annot));
209+
},
210+
py::return_value_policy::copy)
202211
.def(
203212
"getAnnotation",
204-
[](NNGraph::NodeRef n) { return getOrAddCaffe2Annotation(n); })
213+
[](NNGraph::NodeRef n) { return getOrAddCaffe2Annotation(n); },
214+
py::return_value_policy::copy)
205215
.def(
206216
"setAnnotation",
207217
[](NNGraph::NodeRef n, Caffe2Annotation annot) {
@@ -327,7 +337,19 @@ void addNomnigraphMethods(pybind11::module& m) {
327337
.def("setDevice", &Caffe2Annotation::setDevice)
328338
.def("getDevice", &Caffe2Annotation::getDevice)
329339
.def("setDeviceType", &Caffe2Annotation::setDeviceType)
330-
.def("getDeviceType", &Caffe2Annotation::getDeviceType);
340+
.def("getDeviceType", &Caffe2Annotation::getDeviceType)
341+
.def("setKeyNode", &Caffe2Annotation::setKeyNode)
342+
.def(
343+
"getKeyNode",
344+
&Caffe2Annotation::getKeyNode,
345+
py::return_value_policy::reference)
346+
.def("setLengthNode", &Caffe2Annotation::setLengthNode)
347+
.def(
348+
"getLengthNode",
349+
&Caffe2Annotation::getLengthNode,
350+
py::return_value_policy::reference)
351+
.def("setComponentLevels", &Caffe2Annotation::setComponentLevels)
352+
.def("getComponentLevels", &Caffe2Annotation::getComponentLevels);
331353
}
332354

333355
REGISTER_PYBIND_ADDITION(addNomnigraphMethods);

0 commit comments

Comments
 (0)