Skip to content

Commit e416816

Browse files
Fix user defined preprocess function missing prediction issue (opensearch-project#2418) (opensearch-project#2427)
* Fix user defined preprocess function missing prediction issue Signed-off-by: zane-neo <zaniu@amazon.com> * Add validation to predictAction in connector Signed-off-by: zane-neo <zaniu@amazon.com> * Add check to multi-modal non image case Signed-off-by: zane-neo <zaniu@amazon.com> * add UTs Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com> (cherry picked from commit 89f23d2) Co-authored-by: zane-neo <zaniu@amazon.com>
1 parent 5cbeaa4 commit e416816

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+12
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.HashMap;
1212
import java.util.List;
1313
import java.util.Map;
14+
import java.util.Optional;
1415
import java.util.concurrent.ConcurrentHashMap;
1516
import java.util.concurrent.CountDownLatch;
1617
import java.util.concurrent.atomic.AtomicReference;
@@ -27,6 +28,8 @@
2728
import org.opensearch.core.xcontent.NamedXContentRegistry;
2829
import org.opensearch.ml.common.FunctionName;
2930
import org.opensearch.ml.common.connector.Connector;
31+
import org.opensearch.ml.common.connector.ConnectorAction;
32+
import org.opensearch.ml.common.connector.MLPreProcessFunction;
3033
import org.opensearch.ml.common.dataset.MLInputDataset;
3134
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
3235
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -99,6 +102,15 @@ private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocs
99102
return Tuple.tuple(textDocsLength / stepSize + 1, stepSize);
100103
}
101104
} else {
105+
Optional<ConnectorAction> predictAction = getConnector().findPredictAction();
106+
if (predictAction.isEmpty()) {
107+
throw new IllegalArgumentException("no predict action found");
108+
}
109+
String preProcessFunction = predictAction.get().getPreProcessFunction();
110+
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
111+
// user defined preprocess script, this case, the chunk size is always equals to text docs length.
112+
return Tuple.tuple(textDocsLength, 1);
113+
}
102114
// consider as batch.
103115
return Tuple.tuple(1, textDocsLength);
104116
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

+84
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine.algorithms.remote;
77

88
import static org.junit.Assert.assertEquals;
9+
import static org.mockito.ArgumentMatchers.any;
910
import static org.mockito.Mockito.spy;
1011
import static org.mockito.Mockito.times;
1112
import static org.mockito.Mockito.when;
@@ -30,6 +31,7 @@
3031
import org.opensearch.common.settings.Settings;
3132
import org.opensearch.common.util.concurrent.ThreadContext;
3233
import org.opensearch.core.action.ActionListener;
34+
import org.opensearch.ingest.TestTemplateService;
3335
import org.opensearch.ml.common.FunctionName;
3436
import org.opensearch.ml.common.connector.AwsConnector;
3537
import org.opensearch.ml.common.connector.Connector;
@@ -42,6 +44,7 @@
4244
import org.opensearch.ml.common.transport.MLTaskResponse;
4345
import org.opensearch.ml.engine.encryptor.Encryptor;
4446
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
47+
import org.opensearch.script.ScriptService;
4548
import org.opensearch.threadpool.ThreadPool;
4649

4750
import com.google.common.collect.ImmutableList;
@@ -67,10 +70,15 @@ public class AwsConnectorExecutorTest {
6770

6871
Encryptor encryptor;
6972

73+
@Mock
74+
private ScriptService scriptService;
75+
7076
@Before
7177
public void setUp() {
7278
MockitoAnnotations.openMocks(this);
7379
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
80+
when(scriptService.compile(any(), any()))
81+
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}"));
7482
}
7583

7684
@Test
@@ -282,4 +290,80 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg
282290
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
283291
assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
284292
}
293+
294+
@Test
295+
public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictionAction() {
296+
ConnectorAction predictAction = ConnectorAction
297+
.builder()
298+
.actionType(ConnectorAction.ActionType.PREDICT)
299+
.method("POST")
300+
.url("http://openai.com/mock")
301+
.requestBody("{\"input\": ${parameters.input}}")
302+
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
303+
.build();
304+
Map<String, String> credential = ImmutableMap
305+
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
306+
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
307+
Connector connector = AwsConnector
308+
.awsConnectorBuilder()
309+
.name("test connector")
310+
.version("1")
311+
.protocol("http")
312+
.parameters(parameters)
313+
.credential(credential)
314+
.build();
315+
connector.decrypt((c) -> encryptor.decrypt(c));
316+
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
317+
Settings settings = Settings.builder().build();
318+
threadContext = new ThreadContext(settings);
319+
when(executor.getClient()).thenReturn(client);
320+
when(client.threadPool()).thenReturn(threadPool);
321+
when(threadPool.getThreadContext()).thenReturn(threadContext);
322+
323+
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
324+
executor
325+
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
326+
ArgumentCaptor<Exception> exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class);
327+
Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture());
328+
assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException;
329+
assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage());
330+
}
331+
332+
@Test
333+
public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPreProcessFunction() {
334+
ConnectorAction predictAction = ConnectorAction
335+
.builder()
336+
.actionType(ConnectorAction.ActionType.PREDICT)
337+
.method("POST")
338+
.url("http://openai.com/mock")
339+
.requestBody("{\"input\": ${parameters.input}}")
340+
.preProcessFunction(
341+
"\n StringBuilder builder = new StringBuilder();\n builder.append(\"\\\"\");\n String first = params.text_docs[0];\n builder.append(first);\n builder.append(\"\\\"\");\n def parameters = \"{\" +\"\\\"text_inputs\\\":\" + builder + \"}\";\n return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"
342+
)
343+
.build();
344+
Map<String, String> credential = ImmutableMap
345+
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
346+
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
347+
Connector connector = AwsConnector
348+
.awsConnectorBuilder()
349+
.name("test connector")
350+
.version("1")
351+
.protocol("http")
352+
.parameters(parameters)
353+
.credential(credential)
354+
.actions(Arrays.asList(predictAction))
355+
.build();
356+
connector.decrypt((c) -> encryptor.decrypt(c));
357+
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
358+
Settings settings = Settings.builder().build();
359+
threadContext = new ThreadContext(settings);
360+
when(executor.getClient()).thenReturn(client);
361+
when(client.threadPool()).thenReturn(threadPool);
362+
when(threadPool.getThreadContext()).thenReturn(threadContext);
363+
when(executor.getScriptService()).thenReturn(scriptService);
364+
365+
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
366+
executor
367+
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
368+
}
285369
}

0 commit comments

Comments
 (0)