Skip to content

Commit 4edcd17

Browse files
authored
Change httpclient to async (opensearch-project#1958) (opensearch-project#2375)
* Change httpclient from sync to async * Change from CRTAsyncHttpClient to NettyAsyncHttpClient * Add publisher to request * Change sync httpclient to async * Handle error case and return error response in actionLListener * Fix no response when exception * Add content type header * Fix issues found in functional test * Fix no response issue in functional test * fix default step size error * Add track inference duration for async httpclient * Change client appsec highlight issues implementation for async httpclient * Add UTs * Add UTs * Remove unused file * Add UTs * format code * Change error code to honor remote service error code * Add more UTs * Change SSRF code to make it correct for return error stattus * Fix failure UTs and add more UTs * Fix failure ITs * format code * Fix partial success response not correct issue * format code * Fix failure ITs * Add more UTs to increase code coverage * Change url regex * Address comments * format code * Fix failure UTs * Add UT for httpclientFactory throw exception when creating httpclient * format code * Address comments and add modelTensor status code * Address comments * format code * Add status code to process error response * format code * Rebase main after connector level http parameter support * Fix UT * Change error message when remote model return empty and chaange the behavior when one of the requests fails * Add comments\ * Remove redundant builder and change the error code check * format code * Add more UTs for throw exception cases * fix failure UTs * format code * Fix test cases since the error message change * Rebase code * fix failure IT * Add more UTs * Fix duplicate response to client issue * fix duplicate response in channel * change code for all successfully responses case * Address comments * format code * Increase nio httpclient version to fix vulnerbility * Change validate localhost logic to same with existing code * change method signature to private * format code --------- Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent ef435c9 commit 4edcd17

21 files changed

+1298
-759
lines changed

ml-algorithms/build.gradle

+2-1
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,13 @@ dependencies {
6363
}
6464
}
6565

66-
implementation platform('software.amazon.awssdk:bom:2.21.15')
66+
implementation platform('software.amazon.awssdk:bom:2.25.40')
6767
implementation 'software.amazon.awssdk:auth'
6868
implementation 'software.amazon.awssdk:apache-client'
6969
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
7070
implementation 'com.jayway.jsonpath:json-path:2.9.0'
7171
implementation group: 'org.json', name: 'json', version: '20231013'
72+
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.25.40'
7273
}
7374

7475
lombok {

ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77

88
import java.util.Map;
99

10+
import org.opensearch.core.action.ActionListener;
1011
import org.opensearch.ml.common.MLModel;
1112
import org.opensearch.ml.common.input.MLInput;
1213
import org.opensearch.ml.common.output.MLOutput;
14+
import org.opensearch.ml.common.transport.MLTaskResponse;
1315
import org.opensearch.ml.engine.encryptor.Encryptor;
1416

1517
/**
@@ -31,7 +33,13 @@ public interface Predictable {
3133
* @param mlInput input data
3234
* @return predicted results
3335
*/
34-
MLOutput predict(MLInput mlInput);
36+
default MLOutput predict(MLInput mlInput) {
37+
throw new IllegalStateException("Method is not implemented");
38+
}
39+
40+
default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
41+
actionListener.onFailure(new IllegalStateException("Method is not implemented"));
42+
}
3543

3644
/**
3745
* Init model (load model into memory) with ML model content and params.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

+36-98
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,43 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8-
import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
98
import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
10-
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
119
import static software.amazon.awssdk.http.SdkHttpMethod.POST;
1210

13-
import java.io.BufferedReader;
14-
import java.io.InputStreamReader;
15-
import java.net.URI;
16-
import java.nio.charset.StandardCharsets;
1711
import java.security.AccessController;
1812
import java.security.PrivilegedExceptionAction;
1913
import java.time.Duration;
2014
import java.util.List;
2115
import java.util.Map;
16+
import java.util.concurrent.CompletableFuture;
2217

23-
import org.opensearch.OpenSearchStatusException;
2418
import org.opensearch.client.Client;
2519
import org.opensearch.common.util.TokenBucket;
26-
import org.opensearch.core.rest.RestStatus;
20+
import org.opensearch.core.action.ActionListener;
2721
import org.opensearch.ml.common.connector.AwsConnector;
2822
import org.opensearch.ml.common.connector.Connector;
2923
import org.opensearch.ml.common.exception.MLException;
3024
import org.opensearch.ml.common.input.MLInput;
3125
import org.opensearch.ml.common.model.MLGuard;
3226
import org.opensearch.ml.common.output.model.ModelTensors;
3327
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
28+
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
3429
import org.opensearch.script.ScriptService;
3530

3631
import lombok.Getter;
3732
import lombok.Setter;
3833
import lombok.extern.log4j.Log4j2;
39-
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
40-
import software.amazon.awssdk.core.sync.RequestBody;
41-
import software.amazon.awssdk.http.AbortableInputStream;
42-
import software.amazon.awssdk.http.HttpExecuteRequest;
43-
import software.amazon.awssdk.http.HttpExecuteResponse;
44-
import software.amazon.awssdk.http.SdkHttpClient;
45-
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
34+
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
4635
import software.amazon.awssdk.http.SdkHttpFullRequest;
47-
import software.amazon.awssdk.utils.AttributeMap;
36+
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
37+
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
4838

4939
@Log4j2
5040
@ConnectorExecutor(AWS_SIGV4)
5141
public class AwsConnectorExecutor extends AbstractConnectorExecutor {
5242

5343
@Getter
5444
private AwsConnector connector;
55-
private SdkHttpClient httpClient;
5645
@Setter
5746
@Getter
5847
private ScriptService scriptService;
@@ -69,103 +58,52 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
6958
@Getter
7059
private MLGuard mlGuard;
7160

72-
public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
73-
this.connector = (AwsConnector) connector;
74-
this.httpClient = httpClient;
75-
}
61+
private SdkAsyncHttpClient httpClient;
7662

7763
public AwsConnectorExecutor(Connector connector) {
7864
super.initialize(connector);
7965
this.connector = (AwsConnector) connector;
80-
Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout());
81-
Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout());
82-
try (
83-
AttributeMap attributeMap = AttributeMap
84-
.builder()
85-
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout)
86-
.put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout)
87-
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getConnectorClientConfig().getMaxConnections())
88-
.build()
89-
) {
90-
log
91-
.info(
92-
"Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}",
93-
connectionTimeout,
94-
readTimeout,
95-
super.getConnectorClientConfig().getMaxConnections()
96-
);
97-
this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
98-
} catch (RuntimeException e) {
99-
log.error("Error initializing AWS connector HTTP client.", e);
100-
throw e;
101-
} catch (Throwable e) {
102-
log.error("Error initializing AWS connector HTTP client.", e);
103-
throw new MLException(e);
104-
}
66+
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
67+
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
68+
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
69+
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
10570
}
10671

10772
@SuppressWarnings("removal")
10873
@Override
109-
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
74+
public void invokeRemoteModel(
75+
MLInput mlInput,
76+
Map<String, String> parameters,
77+
String payload,
78+
Map<Integer, ModelTensors> tensorOutputs,
79+
ExecutionContext countDownLatch,
80+
ActionListener<List<ModelTensors>> actionListener
81+
) {
11082
try {
111-
String endpoint = connector.getPredictEndpoint(parameters);
112-
RequestBody requestBody = RequestBody.fromString(payload);
113-
114-
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
115-
.builder()
116-
.method(POST)
117-
.uri(URI.create(endpoint))
118-
.contentStreamProvider(requestBody.contentStreamProvider());
119-
Map<String, String> headers = connector.getDecryptedHeaders();
120-
if (headers != null) {
121-
for (String key : headers.keySet()) {
122-
builder.putHeader(key, headers.get(key));
123-
}
124-
}
125-
SdkHttpFullRequest request = builder.build();
126-
HttpExecuteRequest executeRequest = HttpExecuteRequest
83+
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
84+
AsyncExecuteRequest executeRequest = AsyncExecuteRequest
12785
.builder()
12886
.request(signRequest(request))
129-
.contentStreamProvider(request.contentStreamProvider().orElse(null))
87+
.requestContentPublisher(new SimpleHttpContentPublisher(request))
88+
.responseHandler(
89+
new MLSdkAsyncHttpResponseHandler(
90+
countDownLatch,
91+
actionListener,
92+
parameters,
93+
tensorOutputs,
94+
connector,
95+
scriptService,
96+
mlGuard
97+
)
98+
)
13099
.build();
131-
132-
HttpExecuteResponse response = AccessController
133-
.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> httpClient.prepareRequest(executeRequest).call());
134-
int statusCode = response.httpResponse().statusCode();
135-
136-
AbortableInputStream body = null;
137-
if (response.responseBody().isPresent()) {
138-
body = response.responseBody().get();
139-
}
140-
141-
StringBuilder responseBuilder = new StringBuilder();
142-
if (body != null) {
143-
try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) {
144-
String line;
145-
while ((line = reader.readLine()) != null) {
146-
responseBuilder.append(line);
147-
}
148-
}
149-
} else {
150-
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
151-
}
152-
String modelResponse = responseBuilder.toString();
153-
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
154-
throw new IllegalArgumentException("guardrails triggered for LLM output");
155-
}
156-
if (statusCode < 200 || statusCode >= 300) {
157-
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
158-
}
159-
160-
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
161-
tensors.setStatusCode(statusCode);
162-
tensorOutputs.add(tensors);
100+
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
163101
} catch (RuntimeException exception) {
164102
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
165-
throw exception;
103+
actionListener.onFailure(exception);
166104
} catch (Throwable e) {
167105
log.error("Failed to execute predict in aws connector", e);
168-
throw new MLException("Fail to execute predict in aws connector", e);
106+
actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e));
169107
}
170108
}
171109

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

+48-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;
1616

1717
import java.io.IOException;
18+
import java.net.URI;
19+
import java.nio.charset.Charset;
1820
import java.util.ArrayList;
1921
import java.util.HashMap;
2022
import java.util.List;
@@ -34,6 +36,7 @@
3436
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
3537
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
3638
import org.opensearch.ml.common.input.MLInput;
39+
import org.opensearch.ml.common.model.MLGuard;
3740
import org.opensearch.ml.common.output.model.ModelTensor;
3841
import org.opensearch.ml.common.output.model.ModelTensors;
3942
import org.opensearch.script.ScriptService;
@@ -46,7 +49,9 @@
4649
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
4750
import software.amazon.awssdk.auth.signer.Aws4Signer;
4851
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
52+
import software.amazon.awssdk.core.sync.RequestBody;
4953
import software.amazon.awssdk.http.SdkHttpFullRequest;
54+
import software.amazon.awssdk.http.SdkHttpMethod;
5055
import software.amazon.awssdk.regions.Region;
5156

5257
@Log4j2
@@ -179,11 +184,15 @@ public static ModelTensors processOutput(
179184
String modelResponse,
180185
Connector connector,
181186
ScriptService scriptService,
182-
Map<String, String> parameters
187+
Map<String, String> parameters,
188+
MLGuard mlGuard
183189
) throws IOException {
184190
if (modelResponse == null) {
185191
throw new IllegalArgumentException("model response is null");
186192
}
193+
if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) {
194+
throw new IllegalArgumentException("guardrails triggered for LLM output");
195+
}
187196
List<ModelTensor> modelTensors = new ArrayList<>();
188197
Optional<ConnectorAction> predictAction = connector.findPredictAction();
189198
if (predictAction.isEmpty()) {
@@ -252,4 +261,42 @@ public static SdkHttpFullRequest signRequest(
252261

253262
return signer.sign(request, params);
254263
}
264+
265+
public static SdkHttpFullRequest buildSdkRequest(
266+
Connector connector,
267+
Map<String, String> parameters,
268+
String payload,
269+
SdkHttpMethod method
270+
) {
271+
String charset = parameters.getOrDefault("charset", "UTF-8");
272+
RequestBody requestBody;
273+
if (payload != null) {
274+
requestBody = RequestBody.fromString(payload, Charset.forName(charset));
275+
} else {
276+
requestBody = RequestBody.empty();
277+
}
278+
if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) {
279+
log.error("Content length is 0. Aborting request to remote model");
280+
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
281+
}
282+
String endpoint = connector.getPredictEndpoint(parameters);
283+
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
284+
.builder()
285+
.method(method)
286+
.uri(URI.create(endpoint))
287+
.contentStreamProvider(requestBody.contentStreamProvider());
288+
Map<String, String> headers = connector.getDecryptedHeaders();
289+
if (headers != null) {
290+
for (String key : headers.keySet()) {
291+
builder.putHeader(key, headers.get(key));
292+
}
293+
}
294+
if (builder.matchingHeaders("Content-Type").isEmpty()) {
295+
builder.putHeader("Content-Type", "application/json");
296+
}
297+
if (builder.matchingHeaders("Content-Length").isEmpty()) {
298+
builder.putHeader("Content-Length", requestBody.optionalContentLength().get().toString());
299+
}
300+
return builder.build();
301+
}
255302
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
*
3+
* * Copyright OpenSearch Contributors
4+
* * SPDX-License-Identifier: Apache-2.0
5+
*
6+
*/
7+
8+
package org.opensearch.ml.engine.algorithms.remote;
9+
10+
import java.util.concurrent.CountDownLatch;
11+
import java.util.concurrent.atomic.AtomicReference;
12+
13+
import lombok.AllArgsConstructor;
14+
import lombok.Data;
15+
16+
/**
17+
* This class encapsulates several parameters that are used in a split-batch request case.
18+
* A batch request is that in neural-search side multiple fields are send in one request to ml-commons,
19+
* but the remote model doesn't accept list of string inputs so in ml-commons the request needs split.
20+
* sequence is used to identify the index of the split request.
21+
* countDownLatch is used to wait for all the split requests to finish.
22+
* exceptionHolder is used to hold any exception thrown in a split-batch request.
23+
*/
24+
@Data
25+
@AllArgsConstructor
26+
public class ExecutionContext {
27+
// Should never be null
28+
private int sequence;
29+
private CountDownLatch countDownLatch;
30+
// This is to hold any exception thrown in a split-batch request
31+
private AtomicReference<Exception> exceptionHolder;
32+
}

0 commit comments

Comments
 (0)