Skip to content

Commit 512b8da

Browse files
committed
fix master key initialization causing 5 seconds block
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 7c4f9ff commit 512b8da

File tree

12 files changed

+608
-228
lines changed

12 files changed

+608
-228
lines changed

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
import java.util.List;
1414
import java.util.Map;
1515
import java.util.Optional;
16+
import java.util.function.BiConsumer;
17+
import java.util.function.BiFunction;
1618
import java.util.function.Function;
1719
import java.util.regex.Matcher;
1820
import java.util.regex.Pattern;
1921
import org.apache.commons.text.StringSubstitutor;
22+
import org.opensearch.core.action.ActionListener;
2023
import org.opensearch.core.common.io.stream.StreamInput;
2124
import org.opensearch.core.common.io.stream.StreamOutput;
2225
import org.opensearch.core.common.io.stream.Writeable;
@@ -62,8 +65,8 @@ public interface Connector extends ToXContentObject, Writeable {
6265

6366
<T> T createPayload(String action, Map<String, String> parameters);
6467

65-
void decrypt(String action, Function<String, String> function);
66-
void encrypt(Function<String, String> function);
68+
void decrypt(String action, BiConsumer<String, ActionListener<String>> function, ActionListener<String> listener);
69+
void encrypt(BiConsumer<String, ActionListener<String>> function, ActionListener<String> listener);
6770

6871
Connector cloneConnector();
6972

@@ -73,7 +76,7 @@ public interface Connector extends ToXContentObject, Writeable {
7376

7477
void writeTo(StreamOutput out) throws IOException;
7578

76-
void update(MLCreateConnectorInput updateContent, Function<String, String> function);
79+
void update(MLCreateConnectorInput updateContent, BiConsumer<String, ActionListener<String>> function, ActionListener<String> listener);
7780

7881
<T> void parseResponse(T orElse, List<ModelTensor> modelTensors, boolean b) throws IOException;
7982

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+70-12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.commons.text.StringSubstitutor;
1313
import org.opensearch.common.io.stream.BytesStreamOutput;
1414
import org.opensearch.commons.authuser.User;
15+
import org.opensearch.core.action.ActionListener;
1516
import org.opensearch.core.common.io.stream.StreamInput;
1617
import org.opensearch.core.common.io.stream.StreamOutput;
1718
import org.opensearch.core.xcontent.XContentBuilder;
@@ -25,6 +26,11 @@
2526
import java.util.List;
2627
import java.util.Map;
2728
import java.util.Optional;
29+
import java.util.concurrent.atomic.AtomicBoolean;
30+
import java.util.concurrent.atomic.AtomicInteger;
31+
import java.util.function.BiConsumer;
32+
import java.util.function.BiFunction;
33+
import java.util.function.Consumer;
2834
import java.util.function.Function;
2935
import java.util.regex.Matcher;
3036
import java.util.regex.Pattern;
@@ -273,7 +279,7 @@ public void writeTo(StreamOutput out) throws IOException {
273279
}
274280

275281
@Override
276-
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
282+
public void update(MLCreateConnectorInput updateContent, BiConsumer<String, ActionListener<String>> consumer, ActionListener<String> listener) {
277283
if (updateContent.getName() != null) {
278284
this.name = updateContent.getName();
279285
}
@@ -291,7 +297,7 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
291297
}
292298
if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) {
293299
this.credential = updateContent.getCredential();
294-
encrypt(function);
300+
encrypt(consumer, listener);
295301
}
296302
if (updateContent.getActions() != null) {
297303
this.actions = updateContent.getActions();
@@ -349,15 +355,42 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
349355
}
350356

351357
@Override
352-
public void decrypt(String action, Function<String, String> function) {
358+
public void decrypt(String action, BiConsumer<String, ActionListener<String>> consumer, ActionListener<String> listener) {
353359
Map<String, String> decrypted = new HashMap<>();
360+
AtomicBoolean completed = new AtomicBoolean(false);
361+
354362
for (String key : credential.keySet()) {
355-
decrypted.put(key, function.apply(credential.get(key)));
356-
}
357-
this.decryptedCredential = decrypted;
358-
Optional<ConnectorAction> connectorAction = findAction(action);
359-
Map<String, String> headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null;
360-
this.decryptedHeaders = createDecryptedHeaders(headers);
363+
consumer.accept(credential.get(key), new ActionListener<>() {
364+
@Override
365+
public void onResponse(String decryptedValue) {
366+
decrypted.put(key, decryptedValue);
367+
if (decrypted.size() == credential.size() && !completed.get()) {
368+
completed.set(true);
369+
decryptedCredential = decrypted;
370+
Optional<ConnectorAction> connectorAction = findAction(action);
371+
Map<String, String> headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null;
372+
decryptedHeaders = createDecryptedHeaders(headers);
373+
listener.onResponse("All credentials encrypted successfully"); // Notify that decryption is complete
374+
}
375+
}
376+
377+
@Override
378+
public void onFailure(Exception e) {
379+
log.error("Failed to decrypt credential for key: " + key, e);
380+
if (!completed.getAndSet(true)) {
381+
listener.onFailure(e);
382+
}
383+
}
384+
});
385+
}
386+
// Map<String, String> decrypted = new HashMap<>();
387+
// for (String key : credential.keySet()) {
388+
// decrypted.put(key, function.apply(credential.get(key)));
389+
// }
390+
// this.decryptedCredential = decrypted;
391+
// Optional<ConnectorAction> connectorAction = findAction(action);
392+
// Map<String, String> headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null;
393+
// this.decryptedHeaders = createDecryptedHeaders(headers);
361394
}
362395

363396
@Override
@@ -372,10 +405,35 @@ public Connector cloneConnector() {
372405
}
373406

374407
@Override
375-
public void encrypt(Function<String, String> function) {
408+
// public void encrypt(Function<String, String> function) {
409+
// for (String key : credential.keySet()) {
410+
// String encrypted = function.apply(credential.get(key));
411+
// credential.put(key, encrypted);
412+
// }
413+
// }
414+
415+
public void encrypt(BiConsumer<String, ActionListener<String>> consumer, ActionListener<String> listener) {
416+
AtomicBoolean completed = new AtomicBoolean(false);
417+
376418
for (String key : credential.keySet()) {
377-
String encrypted = function.apply(credential.get(key));
378-
credential.put(key, encrypted);
419+
consumer.accept(credential.get(key), new ActionListener<>() {
420+
@Override
421+
public void onResponse(String encrypted) {
422+
credential.put(key, encrypted);
423+
if (credential.entrySet().stream().allMatch(entry -> entry.getValue() != null)) {
424+
completed.set(true);
425+
listener.onResponse("All credentials encrypted successfully");
426+
}
427+
}
428+
429+
@Override
430+
public void onFailure(Exception e) {
431+
log.error("Failed to encrypt credential for key: " + key, e);
432+
if (!completed.getAndSet(true)) {
433+
listener.onFailure(e);
434+
}
435+
}
436+
});
379437
}
380438
}
381439

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ private void validateInput(Input input) {
197197
}
198198
}
199199

200-
public String encrypt(String credential) {
201-
return encryptor.encrypt(credential);
200+
public void encrypt(String credential, ActionListener<String> listener) {
201+
encryptor.encrypt(credential, listener);
202202
}
203203

204204
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

+34-10
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,40 @@ public boolean isModelReady() {
9494
public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
9595
try {
9696
Connector connector = model.getConnector().cloneConnector();
97-
connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
98-
this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
99-
this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE));
100-
this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
101-
this.connectorExecutor.setClient((Client) params.get(CLIENT));
102-
this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY));
103-
this.connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER));
104-
this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
105-
this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
106-
this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
97+
ActionListener<String> decryptListener = new ActionListener<>() {
98+
@Override
99+
public void onResponse(String response) {
100+
try {
101+
connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
102+
connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE));
103+
connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
104+
connectorExecutor.setClient((Client) params.get(CLIENT));
105+
connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY));
106+
connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER));
107+
connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
108+
connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
109+
connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
110+
} catch (Exception e) {
111+
log.error("Failed to init remote model.", e);
112+
}
113+
}
114+
115+
@Override
116+
public void onFailure(Exception e) {
117+
log.error("Failed to decrypt connector credentials.", e);
118+
}
119+
};
120+
connector.decrypt(PREDICT.name(), (credential, listener) -> encryptor.decrypt(credential, decryptListener), decryptListener);
121+
// connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential));
122+
// this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
123+
// this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE));
124+
// this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
125+
// this.connectorExecutor.setClient((Client) params.get(CLIENT));
126+
// this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY));
127+
// this.connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER));
128+
// this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
129+
// this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
130+
// this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
107131
} catch (RuntimeException e) {
108132
log.error("Failed to init remote model.", e);
109133
throw e;

ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
package org.opensearch.ml.engine.encryptor;
77

8+
import org.opensearch.core.action.ActionListener;
9+
810
public interface Encryptor {
911

1012
/**
@@ -13,15 +15,15 @@ public interface Encryptor {
1315
* @param plainText plainText.
1416
* @return String encryptedText.
1517
*/
16-
String encrypt(String plainText);
18+
void encrypt(String plainText, ActionListener<String> listener);
1719

1820
/**
1921
* Takes encryptedText and returns plain text.
2022
*
2123
* @param encryptedText encryptedText.
2224
* @return String plainText.
2325
*/
24-
String decrypt(String encryptedText);
26+
void decrypt(String encryptedText, ActionListener<String> listener);
2527

2628
/**
2729
* Set up the masterKey for dynamic updating

0 commit comments

Comments
 (0)