@@ -100,19 +100,13 @@ private void validateParameters() {
100
100
@ Override
101
101
public DataFrame predict (DataFrame dataFrame , Model model ) {
102
102
if (model == null ) {
103
- throw new RuntimeException ("No model found for KMeans prediction." );
103
+ throw new IllegalArgumentException ("No model found for KMeans prediction." );
104
104
}
105
105
106
106
List <Prediction <ClusterID >> predictions ;
107
107
MutableDataset <ClusterID > predictionDataset = TribuoUtil .generateDataset (dataFrame , new ClusteringFactory (),
108
108
"KMeans prediction data from opensearch" , TribuoOutputType .CLUSTERID );
109
- KMeansModel kMeansModel = null ;
110
- try {
111
- kMeansModel = (KMeansModel ) ModelSerDeSer .deserialize (model .getContent ());
112
- } catch (Exception e ) {
113
- throw new RuntimeException ("Failed to deserialize model." , e .getCause ());
114
- }
115
-
109
+ KMeansModel kMeansModel = (KMeansModel ) ModelSerDeSer .deserialize (model .getContent ());
116
110
predictions = kMeansModel .predict (predictionDataset );
117
111
118
112
List <Map <String , Object >> listClusterID = new ArrayList <>();
@@ -130,11 +124,7 @@ public Model train(DataFrame dataFrame) {
130
124
Model model = new Model ();
131
125
model .setName ("KMeans" );
132
126
model .setVersion (1 );
133
- try {
134
- model .setContent (ModelSerDeSer .serialize (kMeansModel ));
135
- } catch (Exception e ) {
136
- throw new RuntimeException ("Failed to serialize model." , e .getCause ());
137
- }
127
+ model .setContent (ModelSerDeSer .serialize (kMeansModel ));
138
128
139
129
return model ;
140
130
}
0 commit comments