Skip to content

Commit 713dcba

Browse files
authored
fix register client API (opensearch-project#1560)
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 0920ba7 commit 713dcba

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java

+11-6
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,7 @@ public void searchTask(SearchRequest searchRequest, ActionListener<SearchRespons
230230
@Override
231231
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
232232
MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput);
233-
client
234-
.execute(
235-
MLRegisterModelAction.INSTANCE,
236-
registerRequest,
237-
ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
238-
);
233+
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, getMLRegisterModelResponseActionListener(listener));
239234
}
240235

241236
@Override
@@ -266,6 +261,16 @@ private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener
266261
return actionListener;
267262
}
268263

264+
private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseActionListener(
265+
ActionListener<MLRegisterModelResponse> listener
266+
) {
267+
ActionListener<MLRegisterModelResponse> actionListener = wrapActionListener(listener, res -> {
268+
MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res);
269+
return registerModelResponse;
270+
});
271+
return actionListener;
272+
}
273+
269274
private <T extends ActionResponse> ActionListener<T> wrapActionListener(
270275
final ActionListener<T> listener,
271276
final Function<ActionResponse, T> recreate

common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java

+22
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@
77

88
import lombok.Getter;
99
import org.opensearch.core.action.ActionResponse;
10+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
11+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
1012
import org.opensearch.core.common.io.stream.StreamInput;
1113
import org.opensearch.core.common.io.stream.StreamOutput;
1214
import org.opensearch.core.xcontent.ToXContent;
1315
import org.opensearch.core.xcontent.ToXContentObject;
1416
import org.opensearch.core.xcontent.XContentBuilder;
17+
import org.opensearch.ml.common.transport.MLTaskResponse;
1518

19+
import java.io.ByteArrayInputStream;
20+
import java.io.ByteArrayOutputStream;
1621
import java.io.IOException;
22+
import java.io.UncheckedIOException;
1723

1824
@Getter
1925
public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject {
@@ -61,4 +67,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
6167
builder.endObject();
6268
return builder;
6369
}
70+
71+
public static MLRegisterModelResponse fromActionResponse(ActionResponse actionResponse) {
72+
if (actionResponse instanceof MLRegisterModelResponse) {
73+
return (MLRegisterModelResponse) actionResponse;
74+
}
75+
76+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
77+
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
78+
actionResponse.writeTo(osso);
79+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
80+
return new MLRegisterModelResponse(input);
81+
}
82+
} catch (IOException e) {
83+
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelResponse", e);
84+
}
85+
}
6486
}

0 commit comments

Comments
 (0)