
Commit 58a0c3e

Use Adagrad optimiser for Linear regression by default (opensearch-project#3291) (opensearch-project#3303)
* Use AdaGrad optimiser by default in Linear Regression

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>

* Added issue link in the code comment as a reference.

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>

* Apply Spotless

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>

---------

Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
(cherry picked from commit f323141)

Co-authored-by: Rithin Pullela <rithinp@amazon.com>
1 parent 98b2696 · commit 58a0c3e

File tree: 1 file changed, +7 −6 lines

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java

+7 −6
@@ -52,7 +52,7 @@
 public class LinearRegression implements Trainable, Predictable {
     public static final String VERSION = "1.0.0";
     private static final LinearRegressionParams.ObjectiveType DEFAULT_OBJECTIVE_TYPE = LinearRegressionParams.ObjectiveType.SQUARED_LOSS;
-    private static final LinearRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LinearRegressionParams.OptimizerType.SIMPLE_SGD;
+    private static final LinearRegressionParams.OptimizerType DEFAULT_OPTIMIZER_TYPE = LinearRegressionParams.OptimizerType.ADA_GRAD;
     private static final double DEFAULT_LEARNING_RATE = 0.01;
     // Momentum
     private static final double DEFAULT_MOMENTUM_FACTOR = 0;
@@ -134,15 +134,15 @@ private void createOptimiser() {
                 break;
         }
         switch (optimizerType) {
+            case SIMPLE_SGD:
+                optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
+                break;
             case LINEAR_DECAY_SGD:
                 optimiser = SGD.getLinearDecaySGD(learningRate, momentumFactor, momentum);
                 break;
             case SQRT_DECAY_SGD:
                 optimiser = SGD.getSqrtDecaySGD(learningRate, momentumFactor, momentum);
                 break;
-            case ADA_GRAD:
-                optimiser = new AdaGrad(learningRate, epsilon);
-                break;
             case ADA_DELTA:
                 optimiser = new AdaDelta(momentumFactor, epsilon);
                 break;
@@ -153,8 +153,9 @@ private void createOptimiser() {
                 optimiser = new RMSProp(learningRate, momentumFactor, epsilon, decayRate);
                 break;
             default:
-                // Use default SGD with a constant learning rate.
-                optimiser = SGD.getSimpleSGD(learningRate, momentumFactor, momentum);
+                // Use AdaGrad by default, reference issue:
+                // https://github.com/opensearch-project/ml-commons/issues/3210#issuecomment-2556119802
+                optimiser = new AdaGrad(learningRate, epsilon);
                 break;
         }
     }
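For context, here is a minimal, self-contained sketch of what the two paths in the changed switch construct. It assumes Tribuo's org.tribuo.math.optimisers classes (SGD, AdaGrad, AdaDelta, RMSProp), which this file uses; the learningRate value mirrors DEFAULT_LEARNING_RATE above, while the epsilon value is illustrative rather than the plugin's exact default. This is not part of the commit, only an illustration of the new default behaviour.

import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.math.optimisers.SGD;

public class DefaultOptimiserSketch {
    public static void main(String[] args) {
        double learningRate = 0.01; // mirrors DEFAULT_LEARNING_RATE in LinearRegression
        double epsilon = 1e-6;      // illustrative value; the plugin defines its own epsilon default

        // After this commit, the default case (no optimizer specified) builds AdaGrad.
        StochasticGradientOptimiser defaultOptimiser = new AdaGrad(learningRate, epsilon);

        // SIMPLE_SGD is no longer the default but remains selectable through its own case.
        StochasticGradientOptimiser simpleSgd = SGD.getSimpleSGD(learningRate, 0.0, SGD.Momentum.NONE);

        System.out.println(defaultOptimiser + " / " + simpleSgd);
    }
}

AdaGrad adapts a per-feature learning rate from accumulated squared gradients, so it is generally less sensitive to the choice of initial learning rate than plain SGD with a fixed step; the linked issue comment gives the project's rationale for making it the default.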
