Skip to content

Commit 7da0029

Browse files
committed
Fix build
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
1 parent da982a0 commit 7da0029

File tree

5 files changed

+13
-52
lines changed

5 files changed

+13
-52
lines changed

common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java

-6
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,6 @@ public MLUpdateModelInput(StreamInput in) throws IOException {
134134

135135
@Override
136136
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
137-
builder.startObject();
138-
builder.endObject();
139-
return builder;
140-
}
141-
142-
public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Params params) throws IOException {
143137
builder.startObject();
144138
builder.field(MODEL_ID_FIELD, modelId);
145139
if (name != null) {

common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java

+4-24
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public class MLUpdateModelInputTest {
8484
"{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\""
8585
+
8686
"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":" +
87-
"\"test-connector_id\",\"interface\":{}}";
87+
"\"test-connector_id\"}";
8888

8989
@Rule
9090
public ExpectedException exceptionRule = ExpectedException.none();
@@ -166,30 +166,16 @@ public void readInputStreamSuccessWithNullFields() throws IOException {
166166
@Test
167167
public void testToXContent() throws Exception {
168168
String jsonStr = serializationWithToXContent(updateModelInput);
169-
assertEquals("{}", jsonStr);
170-
}
171-
172-
@Test
173-
public void testToXContentForUpdateRequestDoc() throws Exception {
174-
String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput);
175169
assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr);
176170
}
177171

178-
@Test
179-
public void testToXContenttForUpdateRequestDocIncomplete() throws Exception {
180-
String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}";
181-
updateModelInput = MLUpdateModelInput.builder()
182-
.modelId("test-model_id").build();
183-
String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput);
184-
assertEquals(expectedIncompleteInputStr, jsonStr);
185-
}
186-
187172
@Test
188173
public void testToXContentIncomplete() throws Exception {
174+
String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}";
189175
updateModelInput = MLUpdateModelInput.builder()
190176
.modelId("test-model_id").build();
191177
String jsonStr = serializationWithToXContent(updateModelInput);
192-
assertEquals("{}", jsonStr);
178+
assertEquals(expectedIncompleteInputStr, jsonStr);
193179
}
194180

195181
@Test
@@ -238,7 +224,7 @@ public void parseWithIllegalFieldWithoutModel() throws Exception {
238224
"\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1,\"illegal_field\":\"This field need to be skipped.\"}";
239225
testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> {
240226
try {
241-
assertEquals(expectedOutputStr, serializationWithToXContentForUpdateRequestDoc(parsedInput));
227+
assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput));
242228
} catch (IOException e) {
243229
throw new RuntimeException(e);
244230
}
@@ -268,10 +254,4 @@ private String serializationWithToXContent(MLUpdateModelInput input) throws IOEx
268254
return builder.toString();
269255
}
270256

271-
private String serializationWithToXContentForUpdateRequestDoc(MLUpdateModelInput input) throws IOException {
272-
XContentBuilder builder = XContentFactory.jsonBuilder();
273-
input.toXContentForUpdateRequestDoc(builder, ToXContent.EMPTY_PARAMS);
274-
assertNotNull(builder);
275-
return builder.toString();
276-
}
277257
}

plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ private void buildUpdateRequest(
379379
) {
380380
try {
381381
updateModelInput.setLastUpdateTime(Instant.now());
382-
updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
382+
updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
383383
updateRequest.docAsUpsert(true);
384384
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
385385
if (isUpdateModelCache) {
@@ -420,7 +420,7 @@ private void buildUpdateRequest(
420420
Integer.parseInt(updatedVersion)
421421
);
422422
try {
423-
updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
423+
updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
424424
updateRequest.docAsUpsert(true);
425425
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
426426
if (isUpdateModelCache) {

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+5-14
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
99
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
1010

11-
import java.io.IOException;
12-
1311
import org.opensearch.OpenSearchStatusException;
1412
import org.opensearch.action.ActionRequest;
1513
import org.opensearch.action.support.ActionFilters;
@@ -43,9 +41,6 @@
4341
import org.opensearch.tasks.Task;
4442
import org.opensearch.transport.TransportService;
4543

46-
import com.fasterxml.jackson.databind.JsonNode;
47-
import com.fasterxml.jackson.databind.ObjectMapper;
48-
4944
import lombok.AccessLevel;
5045
import lombok.experimental.FieldDefaults;
5146
import lombok.extern.log4j.Log4j2;
@@ -242,19 +237,15 @@ private void validateInputSchema(String modelId, MLInput mlInput) {
242237
if (modelCacheHelper.getModelInterface(modelId) != null && modelCacheHelper.getModelInterface(modelId).get("input") != null) {
243238
String inputSchemaString = modelCacheHelper.getModelInterface(modelId).get("input");
244239
try {
245-
String parametersString = parametersObjectExtractor(mlInput);
246-
MLNodeUtils.validateSchema(inputSchemaString, parametersString);
240+
MLNodeUtils
241+
.validateSchema(
242+
inputSchemaString,
243+
mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()
244+
);
247245
} catch (Exception e) {
248246
throw new IllegalArgumentException("Error validating input schema: " + e.getMessage());
249247
}
250248
}
251249
}
252250

253-
private static String parametersObjectExtractor(MLInput mlInput) throws IOException {
254-
String mlInputString = mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
255-
ObjectMapper mapper = new ObjectMapper();
256-
JsonNode mlInputObject = mapper.readTree(mlInputString);
257-
return mlInputObject.get("parameters").asText();
258-
}
259-
260251
}

plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java

+2-6
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,7 @@ public void testUpdateRequestDocIOException() throws IOException {
651651
doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm();
652652
doReturn(MLModelState.REGISTERED).when(mockModel).getModelState();
653653

654-
doThrow(new IOException("Exception occurred during building update request."))
655-
.when(mockUpdateModelInput)
656-
.toXContentForUpdateRequestDoc(any(), any());
654+
doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any());
657655
transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener);
658656
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IOException.class);
659657
verify(actionListener).onFailure(argumentCaptor.capture());
@@ -702,9 +700,7 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO
702700
return null;
703701
}).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class));
704702

705-
doThrow(new IOException("Exception occurred during building update request."))
706-
.when(mockUpdateModelInput)
707-
.toXContentForUpdateRequestDoc(any(), any());
703+
doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any());
708704
transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener);
709705
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IOException.class);
710706
verify(actionListener).onFailure(argumentCaptor.capture());

0 commit comments

Comments
 (0)