19
19
import org .opensearch .ml .common .dataframe .DataFrame ;
20
20
import org .opensearch .ml .common .parameter .MLParameter ;
21
21
import org .opensearch .ml .common .parameter .MLParameterBuilder ;
22
+ import org .opensearch .ml .engine .regression .LinearRegression ;
22
23
23
24
import java .util .ArrayList ;
24
25
import java .util .List ;
29
30
import static org .opensearch .ml .engine .clustering .KMeans .NUM_THREADS ;
30
31
import static org .opensearch .ml .engine .clustering .KMeans .SEED ;
31
32
import static org .opensearch .ml .engine .helper .KMeansHelper .constructKMeansDataFrame ;
33
+ import static org .opensearch .ml .engine .helper .LinearRegressionHelper .constructLinearRegressionPredictionDataFrame ;
32
34
import static org .opensearch .ml .engine .helper .LinearRegressionHelper .constructLinearRegressionTrainDataFrame ;
35
+ import static org .opensearch .ml .engine .regression .LinearRegression .BETA1 ;
36
+ import static org .opensearch .ml .engine .regression .LinearRegression .BETA2 ;
37
+ import static org .opensearch .ml .engine .regression .LinearRegression .EPSILON ;
38
+ import static org .opensearch .ml .engine .regression .LinearRegression .LEARNING_RATE ;
39
+ import static org .opensearch .ml .engine .regression .LinearRegression .OBJECTIVE ;
40
+ import static org .opensearch .ml .engine .regression .LinearRegression .OPTIMISER ;
41
+ import static org .opensearch .ml .engine .regression .LinearRegression .TARGET ;
33
42
34
43
public class MLEngineTest {
35
44
@ Rule
@@ -44,6 +53,14 @@ public void predictKMeans() {
44
53
predictions .forEach (row -> Assert .assertTrue (row .getValue (0 ).intValue () == 0 || row .getValue (0 ).intValue () == 1 ));
45
54
}
46
55
56
+ @ Test
57
+ public void predictLinearRegression () {
58
+ Model model = trainLinearRegressionModel ();
59
+ DataFrame predictionDataFrame = constructLinearRegressionPredictionDataFrame ();
60
+ DataFrame predictions = MLEngine .predict ("linear_regression" , null , predictionDataFrame , model );
61
+ Assert .assertEquals (2 , predictions .size ());
62
+ }
63
+
47
64
@ Test
48
65
public void trainKMeans () {
49
66
Model model = trainKMeansModel ();
@@ -76,6 +93,13 @@ public void predictUnsupportedAlgorithm() {
76
93
MLEngine .predict (algoName , null , null , null );
77
94
}
78
95
96
+ @ Test
97
+ public void predictWithoutModel () {
98
+ exceptionRule .expect (IllegalArgumentException .class );
99
+ exceptionRule .expectMessage ("No model found for linear regression prediction." );
100
+ MLEngine .predict ("linear_regression" , null , null , null );
101
+ }
102
+
79
103
private Model trainKMeansModel () {
80
104
List <MLParameter > parameters = new ArrayList <>();
81
105
parameters .add (MLParameterBuilder .parameter (SEED , 1L ));
@@ -89,13 +113,13 @@ private Model trainKMeansModel() {
89
113
90
114
private Model trainLinearRegressionModel () {
91
115
List <MLParameter > parameters = new ArrayList <>();
92
- parameters .add (MLParameterBuilder .parameter ("objective" , 0 ));
93
- parameters .add (MLParameterBuilder .parameter ("optimiser" , 5 ));
94
- parameters .add (MLParameterBuilder .parameter ("learning_rate" , 0.01 ));
95
- parameters .add (MLParameterBuilder .parameter ("epsilon" , 1e-6 ));
96
- parameters .add (MLParameterBuilder .parameter ("beta1" , 0.9 ));
97
- parameters .add (MLParameterBuilder .parameter ("beta2" , 0.99 ));
98
- parameters .add (MLParameterBuilder .parameter ("target" , "price" ));
116
+ parameters .add (MLParameterBuilder .parameter (OBJECTIVE , 0 ));
117
+ parameters .add (MLParameterBuilder .parameter (OPTIMISER , 5 ));
118
+ parameters .add (MLParameterBuilder .parameter (LEARNING_RATE , 0.01 ));
119
+ parameters .add (MLParameterBuilder .parameter (EPSILON , 1e-6 ));
120
+ parameters .add (MLParameterBuilder .parameter (BETA1 , 0.9 ));
121
+ parameters .add (MLParameterBuilder .parameter (BETA2 , 0.99 ));
122
+ parameters .add (MLParameterBuilder .parameter (TARGET , "price" ));
99
123
DataFrame trainDataFrame = constructLinearRegressionTrainDataFrame ();
100
124
return MLEngine .train ("linear_regression" , parameters , trainDataFrame );
101
125
}
0 commit comments