Skip to content

Commit 3c9a76d

Browse files
Fix internal connector (#1989) (#1992)
* Fix internal connector Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * spotless fix Signed-off-by: Sicheng Song <sicheng.song@outlook.com> --------- Signed-off-by: Sicheng Song <sicheng.song@outlook.com> (cherry picked from commit 343ae16) Co-authored-by: Sicheng Song <sicheng.song@outlook.com>
1 parent 773f3e6 commit 3c9a76d

File tree

4 files changed

+90
-9
lines changed

4 files changed

+90
-9
lines changed

common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java

+38
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,44 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
140140
return builder;
141141
}
142142

143+
public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Params params) throws IOException {
144+
builder.startObject();
145+
builder.field(MODEL_ID_FIELD, modelId);
146+
if (name != null) {
147+
builder.field(MODEL_NAME_FIELD, name);
148+
}
149+
if (description != null) {
150+
builder.field(DESCRIPTION_FIELD, description);
151+
}
152+
if (version != null) {
153+
builder.field(MODEL_VERSION_FIELD, version);
154+
}
155+
if (modelGroupId != null) {
156+
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
157+
}
158+
if (isEnabled != null) {
159+
builder.field(IS_ENABLED_FIELD, isEnabled);
160+
}
161+
if (rateLimiter != null) {
162+
builder.field(RATE_LIMITER_FIELD, rateLimiter);
163+
}
164+
if (modelConfig != null) {
165+
builder.field(MODEL_CONFIG_FIELD, modelConfig);
166+
}
167+
// Notice that we serialize the updatedConnector to the connector field, in order to be compatible with original internal connector field format.
168+
if (updatedConnector != null) {
169+
builder.field(CONNECTOR_FIELD, updatedConnector);
170+
}
171+
if (connectorId != null) {
172+
builder.field(CONNECTOR_ID_FIELD, connectorId);
173+
}
174+
if (lastUpdateTime != null) {
175+
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
176+
}
177+
builder.endObject();
178+
return builder;
179+
}
180+
143181
@Override
144182
public void writeTo(StreamOutput out) throws IOException {
145183
out.writeString(modelId);

common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java

+39
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,23 @@ public class MLUpdateModelInputTest {
6161
+
6262
"\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}";
6363

64+
private final String expectedOutputStrForUpdateRequestDoc = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":"
65+
+
66+
"\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" +
67+
"{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" +
68+
"{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\""
69+
+
70+
"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" +
71+
"{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":"
72+
+
73+
"{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":"
74+
+
75+
"\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":"
76+
+
77+
"\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":"
78+
+
79+
"\"test-connector_id\",\"last_updated_time\":1}";
80+
6481
private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":"
6582
+
6683
"\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" +
@@ -153,6 +170,21 @@ public void testToXContent() throws Exception {
153170
assertEquals(expectedInputStr, jsonStr);
154171
}
155172

173+
@Test
174+
public void testToXContentForUpdateRequestDoc() throws Exception {
175+
String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput);
176+
assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr);
177+
}
178+
179+
@Test
180+
public void testToXContenttForUpdateRequestDocIncomplete() throws Exception {
181+
String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}";
182+
updateModelInput = MLUpdateModelInput.builder()
183+
.modelId("test-model_id").build();
184+
String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput);
185+
assertEquals(expectedIncompleteInputStr, jsonStr);
186+
}
187+
156188
@Test
157189
public void testToXContentIncomplete() throws Exception {
158190
String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}";
@@ -237,4 +269,11 @@ private String serializationWithToXContent(MLUpdateModelInput input) throws IOEx
237269
assertNotNull(builder);
238270
return builder.toString();
239271
}
272+
273+
private String serializationWithToXContentForUpdateRequestDoc(MLUpdateModelInput input) throws IOException {
274+
XContentBuilder builder = XContentFactory.jsonBuilder();
275+
input.toXContentForUpdateRequestDoc(builder, ToXContent.EMPTY_PARAMS);
276+
assertNotNull(builder);
277+
return builder.toString();
278+
}
240279
}

plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ private void updateRemoteOrTextEmbeddingModel(
200200
MLModel mlModel,
201201
User user,
202202
ActionListener<UpdateResponse> wrappedListener
203-
) {
203+
) throws IOException {
204204
String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId())
205205
&& !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null;
206206
String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null;
@@ -330,7 +330,7 @@ private void updateModelWithRegisteringToAnotherModelGroup(
330330
.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> {
331331
if (hasNewModelGroupPermission) {
332332
mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> {
333-
updateRequestConstructor(
333+
buildUpdateRequest(
334334
modelId,
335335
newModelGroupId,
336336
updateRequest,
@@ -364,11 +364,11 @@ private void updateModelWithRegisteringToAnotherModelGroup(
364364
wrappedListener.onFailure(exception);
365365
}));
366366
} else {
367-
updateRequestConstructor(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache);
367+
buildUpdateRequest(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache);
368368
}
369369
}
370370

371-
private void updateRequestConstructor(
371+
private void buildUpdateRequest(
372372
String modelId,
373373
UpdateRequest updateRequest,
374374
MLUpdateModelInput updateModelInput,
@@ -377,7 +377,7 @@ private void updateRequestConstructor(
377377
) {
378378
try {
379379
updateModelInput.setLastUpdateTime(Instant.now());
380-
updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
380+
updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
381381
updateRequest.docAsUpsert(true);
382382
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
383383
if (isUpdateModelCache) {
@@ -397,7 +397,7 @@ private void updateRequestConstructor(
397397
}
398398
}
399399

400-
private void updateRequestConstructor(
400+
private void buildUpdateRequest(
401401
String modelId,
402402
String newModelGroupId,
403403
UpdateRequest updateRequest,
@@ -418,7 +418,7 @@ private void updateRequestConstructor(
418418
Integer.parseInt(updatedVersion)
419419
);
420420
try {
421-
updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
421+
updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
422422
updateRequest.docAsUpsert(true);
423423
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
424424
if (isUpdateModelCache) {

plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,9 @@ public void testUpdateRequestDocIOException() throws IOException {
651651
doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm();
652652
doReturn(MLModelState.REGISTERED).when(mockModel).getModelState();
653653

654-
doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any());
654+
doThrow(new IOException("Exception occurred during building update request."))
655+
.when(mockUpdateModelInput)
656+
.toXContentForUpdateRequestDoc(any(), any());
655657
transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener);
656658
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IOException.class);
657659
verify(actionListener).onFailure(argumentCaptor.capture());
@@ -700,7 +702,9 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO
700702
return null;
701703
}).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class));
702704

703-
doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any());
705+
doThrow(new IOException("Exception occurred during building update request."))
706+
.when(mockUpdateModelInput)
707+
.toXContentForUpdateRequestDoc(any(), any());
704708
transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener);
705709
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IOException.class);
706710
verify(actionListener).onFailure(argumentCaptor.capture());

0 commit comments

Comments
 (0)