@@ -15,9 +15,11 @@ namespace caffe {
15
15
template <typename Dtype>
16
16
class SGDSolver : public Solver <Dtype> {
17
17
public:
18
- explicit SGDSolver (const SolverParameter& param)
19
- : Solver<Dtype>(param) { PreSolve (); }
20
- explicit SGDSolver (const string& param_file)
18
+ explicit SGDSolver (const SolverParameter& param,
19
+ Solver<Dtype> *root_solver = NULL )
20
+ : Solver<Dtype>(param, root_solver) { PreSolve (); }
21
+ explicit SGDSolver (const string& param_file,
22
+ Solver<Dtype> *root_solver = NULL )
21
23
: Solver<Dtype>(param_file) { PreSolve (); }
22
24
virtual inline const char * type () const { return " SGD" ; }
23
25
@@ -48,10 +50,12 @@ class SGDSolver : public Solver<Dtype> {
48
50
template <typename Dtype>
49
51
class NesterovSolver : public SGDSolver <Dtype> {
50
52
public:
51
- explicit NesterovSolver (const SolverParameter& param)
52
- : SGDSolver<Dtype>(param) {}
53
- explicit NesterovSolver (const string& param_file)
54
- : SGDSolver<Dtype>(param_file) {}
53
+ explicit NesterovSolver (const SolverParameter& param,
54
+ Solver<Dtype> *root_solver = NULL )
55
+ : SGDSolver<Dtype>(param, root_solver) {}
56
+ explicit NesterovSolver (const string& param_file,
57
+ Solver<Dtype> *root_solver = NULL )
58
+ : SGDSolver<Dtype>(param_file, root_solver) {}
55
59
virtual inline const char * type () const { return " Nesterov" ; }
56
60
57
61
protected:
@@ -63,10 +67,14 @@ class NesterovSolver : public SGDSolver<Dtype> {
63
67
template <typename Dtype>
64
68
class AdaGradSolver : public SGDSolver <Dtype> {
65
69
public:
66
- explicit AdaGradSolver (const SolverParameter& param)
67
- : SGDSolver<Dtype>(param) { constructor_sanity_check (); }
68
- explicit AdaGradSolver (const string& param_file)
69
- : SGDSolver<Dtype>(param_file) { constructor_sanity_check (); }
70
+ explicit AdaGradSolver (const SolverParameter& param,
71
+ Solver<Dtype> *root_solver = NULL )
72
+ : SGDSolver<Dtype>(param, root_solver)
73
+ { constructor_sanity_check (); }
74
+ explicit AdaGradSolver (const string& param_file,
75
+ Solver<Dtype> *root_solver = NULL )
76
+ : SGDSolver<Dtype>(param_file, root_solver)
77
+ { constructor_sanity_check (); }
70
78
virtual inline const char * type () const { return " AdaGrad" ; }
71
79
72
80
protected:
@@ -83,10 +91,14 @@ class AdaGradSolver : public SGDSolver<Dtype> {
83
91
template <typename Dtype>
84
92
class RMSPropSolver : public SGDSolver <Dtype> {
85
93
public:
86
- explicit RMSPropSolver (const SolverParameter& param)
87
- : SGDSolver<Dtype>(param) { constructor_sanity_check (); }
88
- explicit RMSPropSolver (const string& param_file)
89
- : SGDSolver<Dtype>(param_file) { constructor_sanity_check (); }
94
+ explicit RMSPropSolver (const SolverParameter& param,
95
+ Solver<Dtype> *root_solver = NULL )
96
+ : SGDSolver<Dtype>(param, root_solver)
97
+ { constructor_sanity_check (); }
98
+ explicit RMSPropSolver (const string& param_file,
99
+ Solver<Dtype> *root_solver = NULL )
100
+ : SGDSolver<Dtype>(param_file, root_solver)
101
+ { constructor_sanity_check (); }
90
102
virtual inline const char * type () const { return " RMSProp" ; }
91
103
92
104
protected:
@@ -106,10 +118,12 @@ class RMSPropSolver : public SGDSolver<Dtype> {
106
118
template <typename Dtype>
107
119
class AdaDeltaSolver : public SGDSolver <Dtype> {
108
120
public:
109
- explicit AdaDeltaSolver (const SolverParameter& param)
110
- : SGDSolver<Dtype>(param) { AdaDeltaPreSolve (); }
111
- explicit AdaDeltaSolver (const string& param_file)
112
- : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve (); }
121
+ explicit AdaDeltaSolver (const SolverParameter& param,
122
+ Solver<Dtype> *root_solver = NULL )
123
+ : SGDSolver<Dtype>(param, root_solver) { AdaDeltaPreSolve (); }
124
+ explicit AdaDeltaSolver (const string& param_file,
125
+ Solver<Dtype> *root_solver = NULL )
126
+ : SGDSolver<Dtype>(param_file, root_solver) { AdaDeltaPreSolve (); }
113
127
virtual inline const char * type () const { return " AdaDelta" ; }
114
128
115
129
protected:
@@ -130,10 +144,12 @@ class AdaDeltaSolver : public SGDSolver<Dtype> {
130
144
template <typename Dtype>
131
145
class AdamSolver : public SGDSolver <Dtype> {
132
146
public:
133
- explicit AdamSolver (const SolverParameter& param)
134
- : SGDSolver<Dtype>(param) { AdamPreSolve ();}
135
- explicit AdamSolver (const string& param_file)
136
- : SGDSolver<Dtype>(param_file) { AdamPreSolve (); }
147
+ explicit AdamSolver (const SolverParameter& param,
148
+ Solver<Dtype> *root_solver = NULL )
149
+ : SGDSolver<Dtype>(param, root_solver) { AdamPreSolve ();}
150
+ explicit AdamSolver (const string& param_file,
151
+ Solver<Dtype> *root_solver = NULL )
152
+ : SGDSolver<Dtype>(param_file, root_solver) { AdamPreSolve (); }
137
153
virtual inline const char * type () const { return " Adam" ; }
138
154
139
155
protected:
0 commit comments