Skip to content

Commit 24a4c9c

Browse files
spbjssAlex
authored and committed
Fix the risks found by PenTest (opensearch-project#76)
* Create JvmService instance on demand. Signed-off-by: Alex <pengsun@amazon.com> * Move the ml_parameters from XContent to the request parameters to avoid the conflict with search XContent input. Signed-off-by: Alex <pengsun@amazon.com> * Fix the security risks found by PenTest. 1. Unhandled 500 server error. 2. Insecure deserialization. Signed-off-by: Alex <pengsun@amazon.com> * Remove unnecessary '*' from the welcome list of model deserializer. Signed-off-by: Alex <pengsun@amazon.com> Co-authored-by: Alex <pengsun@amazon.com>
1 parent c2133bc commit 24a4c9c

File tree

5 files changed

+49
-13
lines changed

5 files changed

+49
-13
lines changed

ml-algorithms/build.gradle

+2
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,9 @@ dependencies {
2626
compile group: 'org.reflections', name: 'reflections', version: '0.9.12'
2727
compile group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.0.2'
2828
compile group: 'org.tribuo', name: 'tribuo-regression-sgd', version: '4.0.2'
29+
compile group: 'commons-io', name: 'commons-io', version: '2.11.0'
2930
testCompile group: 'junit', name: 'junit', version: '4.12'
31+
testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.9.0'
3032
}
3133

3234
jacocoTestReport {

ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ModelSerDeSer.java

+19-4
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,27 @@
1313
package org.opensearch.ml.engine.utils;
1414

1515
import lombok.experimental.UtilityClass;
16+
import org.apache.commons.io.serialization.ValidatingObjectInputStream;
1617
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;
1718

1819
import java.io.ByteArrayInputStream;
1920
import java.io.ByteArrayOutputStream;
2021
import java.io.IOException;
21-
import java.io.ObjectInputStream;
2222
import java.io.ObjectOutputStream;
2323

2424
@UtilityClass
2525
public class ModelSerDeSer {
26+
// Welcome list includes OpenSearch ml plugin classes, JDK common classes and Tribuo libraries.
27+
public static final String[] ACCEPT_CLASS_PATTERNS = {
28+
"java.lang.*",
29+
"java.util.*",
30+
"java.time.*",
31+
"org.opensearch.ml.*",
32+
"*org.tribuo.*",
33+
"com.oracle.labs.*",
34+
"[*"
35+
};
36+
2637
public static byte[] serialize(Object model) {
2738
byte[] res = new byte[0];
2839
try {
@@ -44,9 +55,13 @@ public static Object deserialize(byte[] modelBin) {
4455
Object res;
4556
try {
4657
ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin);
47-
ObjectInputStream objectInputStream = new ObjectInputStream(inputStream);
48-
res = objectInputStream.readObject();
49-
objectInputStream.close();
58+
ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream);
59+
60+
// Validate the model class type to avoid deserialization attack.
61+
validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS);
62+
63+
res = validatingObjectInputStream.readObject();
64+
validatingObjectInputStream.close();
5065
inputStream.close();
5166
} catch (IOException | ClassNotFoundException e) {
5267
throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause());

ml-algorithms/src/test/java/org/opensearch/ml/engine/ModelSerDeSerTest.java

+26-5
Original file line number | Diff line number | Diff line change
@@ -12,21 +12,42 @@
1212

1313
package org.opensearch.ml.engine;
1414

15+
import org.junit.Rule;
1516
import org.junit.Test;
17+
import org.junit.rules.ExpectedException;
18+
import org.opensearch.ml.engine.algorithms.clustering.KMeans;
19+
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;
1620
import org.opensearch.ml.engine.utils.ModelSerDeSer;
21+
import org.tribuo.clustering.kmeans.KMeansModel;
1722

18-
import static org.junit.Assert.assertTrue;
23+
import java.util.ArrayList;
24+
import java.util.Arrays;
1925

20-
import java.io.IOException;
26+
import static org.junit.Assert.assertFalse;
27+
import static org.junit.Assert.assertTrue;
28+
import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;
2129

2230
public class ModelSerDeSerTest {
23-
private final DummyModel dummyModel = new DummyModel();
31+
@Rule
32+
public ExpectedException thrown = ExpectedException.none();
33+
34+
private final Object dummyModel = new Object();
2435

2536
@Test
26-
public void testModelSerDeSer() throws IOException, ClassNotFoundException {
37+
public void testModelSerDeSerBlocklModel() {
38+
thrown.expect(ModelSerDeSerException.class);
2739
byte[] modelBin = ModelSerDeSer.serialize(dummyModel);
28-
DummyModel model = (DummyModel) ModelSerDeSer.deserialize(modelBin);
40+
Object model = ModelSerDeSer.deserialize(modelBin);
2941
assertTrue(model.equals(dummyModel));
3042
}
3143

44+
@Test
45+
public void testModelSerDeSerKMeans() {
46+
KMeans kMeans = new KMeans(new ArrayList<>());
47+
Model model = kMeans.train(constructKMeansDataFrame(100));
48+
49+
KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.content);
50+
byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel);
51+
assertFalse(Arrays.equals(serializedModel, model.content));
52+
}
3253
}

plugin/src/main/java/org/opensearch/ml/indices/MLInputDatasetHandler.java

+1-2
Original file line number | Diff line number | Diff line change
@@ -74,8 +74,7 @@ public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<
7474

7575
client.search(searchRequest, ActionListener.wrap(r -> {
7676
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
77-
// todo: add specific exception
78-
listener.onFailure(new RuntimeException("No document found"));
77+
listener.onFailure(new IllegalArgumentException("No document found"));
7978
return;
8079
}
8180
SearchHits hits = r.getHits();

plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionIT.java

+1-2
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
import org.opensearch.ResourceNotFoundException;
2222
import org.opensearch.action.ActionFuture;
2323
import org.opensearch.action.ActionRequestValidationException;
24-
import org.opensearch.common.io.stream.NotSerializableExceptionWrapper;
2524
import org.opensearch.ml.common.dataset.MLInputDataset;
2625
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
2726
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
@@ -99,6 +98,6 @@ public void testPredictionWithEmptyDataset() throws IOException {
9998
emptySearchInputDataset
10099
);
101100
ActionFuture<MLPredictionTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest);
102-
expectThrows(NotSerializableExceptionWrapper.class, () -> predictionFuture.actionGet());
101+
expectThrows(IllegalArgumentException.class, () -> predictionFuture.actionGet());
103102
}
104103
}

0 commit comments

Comments
 (0)