Skip to content

Commit 8e03fc4

Browse files
jngz-esjackiehanyang
authored andcommitted
add linear regression predict (opensearch-project#35)
* add linear regression predict
1 parent d239c25 commit 8e03fc4

File tree

6 files changed

+89
-9
lines changed

6 files changed

+89
-9
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

+3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ public static DataFrame predict(String algoName, List<MLParameter> parameters, D
3333
case MLAlgoNames.KMEANS:
3434
KMeans kMeans = new KMeans(parameters);
3535
return kMeans.predict(dataFrame, model);
36+
case MLAlgoNames.LINEAR_REGRESSION:
37+
LinearRegression linearRegression = new LinearRegression(parameters);
38+
return linearRegression.predict(dataFrame, model);
3639
default:
3740
throw new IllegalArgumentException("Unsupported algorithm: " + algoName);
3841
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/regression/LinearRegression.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
package org.opensearch.ml.engine.regression;
1414

1515
import org.opensearch.ml.common.dataframe.DataFrame;
16+
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
1617
import org.opensearch.ml.common.parameter.MLParameter;
1718
import org.opensearch.ml.engine.MLAlgo;
1819
import org.opensearch.ml.engine.Model;
1920
import org.opensearch.ml.engine.contants.TribuoOutputType;
2021
import org.opensearch.ml.engine.utils.ModelSerDeSer;
2122
import org.opensearch.ml.engine.utils.TribuoUtil;
2223
import org.tribuo.MutableDataset;
24+
import org.tribuo.Prediction;
2325
import org.tribuo.math.StochasticGradientOptimiser;
2426
import org.tribuo.math.optimisers.AdaDelta;
2527
import org.tribuo.math.optimisers.AdaGrad;
@@ -34,7 +36,10 @@
3436
import org.tribuo.regression.sgd.objectives.Huber;
3537
import org.tribuo.regression.sgd.objectives.SquaredLoss;
3638

39+
import java.util.ArrayList;
40+
import java.util.Collections;
3741
import java.util.List;
42+
import java.util.Map;
3843

3944
public class LinearRegression implements MLAlgo {
4045
public static final String OBJECTIVE = "objective";
@@ -200,8 +205,18 @@ private void validateParameters() {
200205

201206
@Override
202207
public DataFrame predict(DataFrame dataFrame, Model model) {
203-
//TODO
204-
throw new RuntimeException("Unsupported predict.");
208+
if (model == null) {
209+
throw new IllegalArgumentException("No model found for linear regression prediction.");
210+
}
211+
212+
org.tribuo.Model<Regressor> regressionModel = (org.tribuo.Model<Regressor>) ModelSerDeSer.deserialize(model.getContent());
213+
MutableDataset<Regressor> predictionDataset = TribuoUtil.generateDataset(dataFrame, new RegressionFactory(),
214+
"Linear regression prediction data from opensearch", TribuoOutputType.REGRESSOR);
215+
List<Prediction<Regressor>> predictions = regressionModel.predict(predictionDataset);
216+
List<Map<String, Object>> listPrediction = new ArrayList<>();
217+
predictions.forEach(e -> listPrediction.add(Collections.singletonMap("Prediction", e.getOutput().getValues()[0])));
218+
219+
return DataFrameBuilder.load(listPrediction);
205220
}
206221

207222
@Override

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/TribuoUtil.java

+4
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame
6868
case CLUSTERID:
6969
example = new ArrayExample<T>((T) new ClusterID(ClusterID.UNASSIGNED), featureNamesValues.v1(), featureNamesValues.v2()[i]);
7070
break;
71+
case REGRESSOR:
72+
//Create single dimension tribuo regressor with name DIM-0 and value double NaN.
73+
example = new ArrayExample<T>((T) new Regressor("DIM-0", Double.NaN), featureNamesValues.v1(), featureNamesValues.v2()[i]);
74+
break;
7175
default:
7276
throw new IllegalArgumentException("unknown type:" + outputType);
7377
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java

+31-7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.opensearch.ml.common.dataframe.DataFrame;
2020
import org.opensearch.ml.common.parameter.MLParameter;
2121
import org.opensearch.ml.common.parameter.MLParameterBuilder;
22+
import org.opensearch.ml.engine.regression.LinearRegression;
2223

2324
import java.util.ArrayList;
2425
import java.util.List;
@@ -29,7 +30,15 @@
2930
import static org.opensearch.ml.engine.clustering.KMeans.NUM_THREADS;
3031
import static org.opensearch.ml.engine.clustering.KMeans.SEED;
3132
import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
33+
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
3234
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
35+
import static org.opensearch.ml.engine.regression.LinearRegression.BETA1;
36+
import static org.opensearch.ml.engine.regression.LinearRegression.BETA2;
37+
import static org.opensearch.ml.engine.regression.LinearRegression.EPSILON;
38+
import static org.opensearch.ml.engine.regression.LinearRegression.LEARNING_RATE;
39+
import static org.opensearch.ml.engine.regression.LinearRegression.OBJECTIVE;
40+
import static org.opensearch.ml.engine.regression.LinearRegression.OPTIMISER;
41+
import static org.opensearch.ml.engine.regression.LinearRegression.TARGET;
3342

3443
public class MLEngineTest {
3544
@Rule
@@ -44,6 +53,14 @@ public void predictKMeans() {
4453
predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
4554
}
4655

56+
@Test
57+
public void predictLinearRegression() {
58+
Model model = trainLinearRegressionModel();
59+
DataFrame predictionDataFrame = constructLinearRegressionPredictionDataFrame();
60+
DataFrame predictions = MLEngine.predict("linear_regression", null, predictionDataFrame, model);
61+
Assert.assertEquals(2, predictions.size());
62+
}
63+
4764
@Test
4865
public void trainKMeans() {
4966
Model model = trainKMeansModel();
@@ -76,6 +93,13 @@ public void predictUnsupportedAlgorithm() {
7693
MLEngine.predict(algoName, null, null, null);
7794
}
7895

96+
@Test
97+
public void predictWithoutModel() {
98+
exceptionRule.expect(IllegalArgumentException.class);
99+
exceptionRule.expectMessage("No model found for linear regression prediction.");
100+
MLEngine.predict("linear_regression", null, null, null);
101+
}
102+
79103
private Model trainKMeansModel() {
80104
List<MLParameter> parameters = new ArrayList<>();
81105
parameters.add(MLParameterBuilder.parameter(SEED, 1L));
@@ -89,13 +113,13 @@ private Model trainKMeansModel() {
89113

90114
private Model trainLinearRegressionModel() {
91115
List<MLParameter> parameters = new ArrayList<>();
92-
parameters.add(MLParameterBuilder.parameter("objective", 0));
93-
parameters.add(MLParameterBuilder.parameter("optimiser", 5));
94-
parameters.add(MLParameterBuilder.parameter("learning_rate", 0.01));
95-
parameters.add(MLParameterBuilder.parameter("epsilon", 1e-6));
96-
parameters.add(MLParameterBuilder.parameter("beta1", 0.9));
97-
parameters.add(MLParameterBuilder.parameter("beta2", 0.99));
98-
parameters.add(MLParameterBuilder.parameter("target", "price"));
116+
parameters.add(MLParameterBuilder.parameter(OBJECTIVE, 0));
117+
parameters.add(MLParameterBuilder.parameter(OPTIMISER, 5));
118+
parameters.add(MLParameterBuilder.parameter(LEARNING_RATE, 0.01));
119+
parameters.add(MLParameterBuilder.parameter(EPSILON, 1e-6));
120+
parameters.add(MLParameterBuilder.parameter(BETA1, 0.9));
121+
parameters.add(MLParameterBuilder.parameter(BETA2, 0.99));
122+
parameters.add(MLParameterBuilder.parameter(TARGET, "price"));
99123
DataFrame trainDataFrame = constructLinearRegressionTrainDataFrame();
100124
return MLEngine.train("linear_regression", parameters, trainDataFrame);
101125
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/helper/LinearRegressionHelper.java

+13
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,17 @@ public static DataFrame constructLinearRegressionTrainDataFrame() {
4141

4242
return DataFrameBuilder.load(columnMetas, rows);
4343
}
44+
45+
public static DataFrame constructLinearRegressionPredictionDataFrame() {
46+
double[] feet = new double[]{5000, 5500};
47+
String[] columnNames = new String[]{"feet"};
48+
ColumnMeta[] columnMetas = Arrays.stream(columnNames).map(e -> new ColumnMeta(e, ColumnType.DOUBLE)).toArray(ColumnMeta[]::new);
49+
List<Map<String, Object>> rows = new ArrayList<>();
50+
for (int i=0; i<feet.length; ++i) {
51+
Map<String, Object> row = new HashMap<>();
52+
row.put("feet", feet[i]);
53+
rows.add(row);
54+
}
55+
return DataFrameBuilder.load(columnMetas, rows);
56+
}
4457
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/regression/LinearRegressionTest.java

+21
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.ArrayList;
2626
import java.util.List;
2727

28+
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
2829
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
2930
import static org.opensearch.ml.engine.regression.LinearRegression.BETA1;
3031
import static org.opensearch.ml.engine.regression.LinearRegression.BETA2;
@@ -41,6 +42,7 @@ public class LinearRegressionTest {
4142

4243
private List<MLParameter> parameters = new ArrayList<>();
4344
private DataFrame trainDataFrame;
45+
private DataFrame predictionDataFrame;
4446

4547
@Before
4648
public void setUp() {
@@ -51,6 +53,25 @@ public void setUp() {
5153
parameters.add(MLParameterBuilder.parameter(BETA1, 0.9));
5254
parameters.add(MLParameterBuilder.parameter(BETA2, 0.99));
5355
trainDataFrame = constructLinearRegressionTrainDataFrame();
56+
predictionDataFrame = constructLinearRegressionPredictionDataFrame();
57+
}
58+
59+
@Test
60+
public void predict() {
61+
parameters.add(MLParameterBuilder.parameter(TARGET, "price"));
62+
LinearRegression regression = new LinearRegression(parameters);
63+
Model model = regression.train(trainDataFrame);
64+
DataFrame predictions = regression.predict(predictionDataFrame, model);
65+
Assert.assertEquals(2, predictions.size());
66+
}
67+
68+
@Test
69+
public void predictWithoutModel() {
70+
exceptionRule.expect(IllegalArgumentException.class);
71+
exceptionRule.expectMessage("No model found for linear regression prediction.");
72+
parameters.add(MLParameterBuilder.parameter(TARGET, "price"));
73+
LinearRegression regression = new LinearRegression(parameters);
74+
regression.predict(predictionDataFrame, null);
5475
}
5576

5677
@Test

0 commit comments

Comments
 (0)