Skip to content

Commit c6b8b43

Browse files
committed
ML Interface poc
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
1 parent 479830d commit c6b8b43

File tree

13 files changed

+285
-27
lines changed

13 files changed

+285
-27
lines changed

build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ buildscript {
2323
}
2424

2525
common_utils_version = System.getProperty("common_utils.version", opensearch_build)
26-
kotlin_version = System.getProperty("kotlin.version", "1.8.21")
26+
kotlin_version = System.getProperty("kotlin.version", "1.9.23")
2727
}
2828

2929
repositories {

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+4
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,12 @@ public class CommonValue {
160160
+ "\" : {\"type\": \"flat_object\"},\n"
161161
+ " \""
162162
+ AbstractConnector.ACTIONS_FIELD
163+
+ "\" : {\"type\": \"flat_object\"},\n"
164+
+ " \""
165+
+ AbstractConnector.MODEL_INTERFACE_FIELD
163166
+ "\" : {\"type\": \"flat_object\"}\n";
164167

168+
165169
public static final String ML_MODEL_INDEX_MAPPING = "{\n"
166170
+ " \"_meta\": {\"schema_version\": "
167171
+ ML_MODEL_INDEX_SCHEMA_VERSION

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public abstract class AbstractConnector implements Connector {
4444
public static final String OWNER_FIELD = "owner";
4545
public static final String ACCESS_FIELD = "access";
4646
public static final String CLIENT_CONFIG_FIELD = "client_config";
47+
public static final String MODEL_INTERFACE_FIELD = "model_interface";
4748

4849

4950
protected String name;
@@ -69,6 +70,7 @@ public abstract class AbstractConnector implements Connector {
6970
protected Instant lastUpdateTime;
7071
@Setter
7172
protected ConnectorClientConfig connectorClientConfig;
73+
protected Map<String, String> modelInterface;
7274

7375
protected Map<String, String> createPredictDecryptedHeaders(Map<String, String> headers) {
7476
if (headers == null) {

common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ public class AwsConnector extends HttpConnector {
3232
public AwsConnector(String name, String description, String version, String protocol,
3333
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
3434
List<String> backendRoles, AccessMode accessMode, User owner,
35-
ConnectorClientConfig connectorClientConfig) {
35+
ConnectorClientConfig connectorClientConfig, Map<String, String> modelInterface) {
3636
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode,
37-
owner, connectorClientConfig);
37+
owner, connectorClientConfig, modelInterface);
3838
validate();
3939
}
4040

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ public interface Connector extends ToXContentObject, Writeable {
6060

6161
String getPredictHttpMethod();
6262

63+
Map<String, String> getModelInterface();
64+
6365
<T> T createPredictPayload(Map<String, String> parameters);
6466

6567
void decrypt(Function<String, String> function);

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+21-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
1010
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
11-
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
12-
import static org.opensearch.ml.common.utils.StringUtils.isJson;
11+
import static org.opensearch.ml.common.utils.StringUtils.*;
1312

1413
import java.io.IOException;
1514
import java.time.Instant;
@@ -53,7 +52,7 @@ public class HttpConnector extends AbstractConnector {
5352
public HttpConnector(String name, String description, String version, String protocol,
5453
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
5554
List<String> backendRoles, AccessMode accessMode, User owner,
56-
ConnectorClientConfig connectorClientConfig) {
55+
ConnectorClientConfig connectorClientConfig, Map<String, String> modelInterface) {
5756
validateProtocol(protocol);
5857
this.name = name;
5958
this.description = description;
@@ -66,6 +65,7 @@ public HttpConnector(String name, String description, String version, String pro
6665
this.access = accessMode;
6766
this.owner = owner;
6867
this.connectorClientConfig = connectorClientConfig;
68+
this.modelInterface = modelInterface;
6969

7070
}
7171

@@ -127,6 +127,9 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
127127
case CLIENT_CONFIG_FIELD:
128128
connectorClientConfig = ConnectorClientConfig.parse(parser);
129129
break;
130+
case MODEL_INTERFACE_FIELD:
131+
modelInterface = getInterfaceMap(parser.map());
132+
break;
130133
default:
131134
parser.skipChildren();
132135
break;
@@ -176,6 +179,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
176179
if (connectorClientConfig != null) {
177180
builder.field(CLIENT_CONFIG_FIELD, connectorClientConfig);
178181
}
182+
if (modelInterface != null) {
183+
builder.field(MODEL_INTERFACE_FIELD, modelInterface);
184+
}
179185
builder.endObject();
180186
return builder;
181187
}
@@ -219,6 +225,9 @@ private void parseFromStream(StreamInput input) throws IOException {
219225
if (input.readBoolean()) {
220226
this.connectorClientConfig = new ConnectorClientConfig(input);
221227
}
228+
if (input.readBoolean()) {
229+
this.modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
230+
}
222231
}
223232

224233
@Override
@@ -269,6 +278,12 @@ public void writeTo(StreamOutput out) throws IOException {
269278
} else {
270279
out.writeBoolean(false);
271280
}
281+
if (modelInterface != null) {
282+
out.writeBoolean(true);
283+
out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString);
284+
} else {
285+
out.writeBoolean(false);
286+
}
272287
}
273288

274289
@Override
@@ -304,6 +319,9 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
304319
if (updateContent.getConnectorClientConfig() != null) {
305320
this.connectorClientConfig = updateContent.getConnectorClientConfig();
306321
}
322+
if (updateContent.getModelInterface() != null && updateContent.getModelInterface().size() > 0) {
323+
this.modelInterface = updateContent.getModelInterface();
324+
}
307325
}
308326

309327
@Override

common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java

+27-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import java.util.Map;
2828

2929
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
30-
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
30+
import static org.opensearch.ml.common.utils.StringUtils.*;
3131

3232
@Data
3333
public class MLCreateConnectorInput implements ToXContentObject, Writeable {
@@ -47,6 +47,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
4747
public static final String DRY_RUN_FIELD = "dry_run";
4848

4949
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG = Version.V_2_13_0;
50+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_INTERFACE = Version.V_2_13_0;
5051

5152
public static final String DRY_RUN_CONNECTOR_NAME = "dryRunConnector";
5253

@@ -63,6 +64,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
6364
private boolean dryRun;
6465
private boolean updateConnector;
6566
private ConnectorClientConfig connectorClientConfig;
67+
private Map<String, String> modelInterface;
6668

6769

6870
@Builder(toBuilder = true)
@@ -78,7 +80,8 @@ public MLCreateConnectorInput(String name,
7880
AccessMode access,
7981
boolean dryRun,
8082
boolean updateConnector,
81-
ConnectorClientConfig connectorClientConfig
83+
ConnectorClientConfig connectorClientConfig,
84+
Map<String, String> modelInterface
8285

8386
) {
8487
if (!dryRun && !updateConnector) {
@@ -105,6 +108,7 @@ public MLCreateConnectorInput(String name,
105108
this.dryRun = dryRun;
106109
this.updateConnector = updateConnector;
107110
this.connectorClientConfig = connectorClientConfig;
111+
this.modelInterface = modelInterface;
108112

109113
}
110114

@@ -125,6 +129,7 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update
125129
AccessMode access = null;
126130
boolean dryRun = false;
127131
ConnectorClientConfig connectorClientConfig = null;
132+
Map<String, String> modelInterface = new HashMap<>();
128133

129134
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
130135
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -176,13 +181,16 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update
176181
case AbstractConnector.CLIENT_CONFIG_FIELD:
177182
connectorClientConfig = ConnectorClientConfig.parse(parser);
178183
break;
184+
case AbstractConnector.MODEL_INTERFACE_FIELD:
185+
modelInterface = getInterfaceMap(parser.map());
186+
break;
179187
default:
180188
parser.skipChildren();
181189
break;
182190
}
183191
}
184192
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions,
185-
backendRoles, addAllBackendRoles, access, dryRun, updateConnector, connectorClientConfig);
193+
backendRoles, addAllBackendRoles, access, dryRun, updateConnector, connectorClientConfig, modelInterface);
186194
}
187195

188196
@Override
@@ -221,6 +229,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
221229
if (connectorClientConfig != null) {
222230
builder.field(AbstractConnector.CLIENT_CONFIG_FIELD, connectorClientConfig);
223231
}
232+
if (modelInterface != null) {
233+
builder.field(AbstractConnector.MODEL_INTERFACE_FIELD, modelInterface);
234+
}
224235
builder.endObject();
225236
return builder;
226237
}
@@ -276,6 +287,14 @@ public void writeTo(StreamOutput output) throws IOException {
276287
output.writeBoolean(false);
277288
}
278289
}
290+
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_INTERFACE)) {
291+
if (modelInterface != null) {
292+
output.writeBoolean(true);
293+
output.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString);
294+
} else {
295+
output.writeBoolean(false);
296+
}
297+
}
279298
}
280299

281300
public MLCreateConnectorInput(StreamInput input) throws IOException {
@@ -311,6 +330,10 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
311330
this.connectorClientConfig = new ConnectorClientConfig(input);
312331
}
313332
}
314-
333+
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_INTERFACE)) {
334+
if (input.readBoolean()) {
335+
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
336+
}
337+
}
315338
}
316339
}

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

+23-4
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919
import java.security.AccessController;
2020
import java.security.PrivilegedActionException;
2121
import java.security.PrivilegedExceptionAction;
22-
import java.util.ArrayList;
23-
import java.util.HashMap;
24-
import java.util.List;
25-
import java.util.Map;
22+
import java.util.*;
2623
import java.util.regex.Matcher;
2724
import java.util.regex.Pattern;
2825

@@ -94,6 +91,28 @@ public static Map<String, Object> fromJson(String jsonStr, String defaultKey) {
9491
return result;
9592
}
9693

94+
public static Map<String, String> getInterfaceMap(Map<String, ?> interfaceObjs) {
95+
Map<String, String> parameters = new HashMap<>();
96+
for (String key : interfaceObjs.keySet()) {
97+
if (Objects.equals(key, "input") || Objects.equals(key, "output")) {
98+
Object value = interfaceObjs.get(key);
99+
try {
100+
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
101+
if (value instanceof String) {
102+
parameters.put(key, (String)value);
103+
} else {
104+
parameters.put(key, gson.toJson(value));
105+
}
106+
return null;
107+
});
108+
} catch (PrivilegedActionException e) {
109+
throw new RuntimeException(e);
110+
}
111+
}
112+
}
113+
return parameters;
114+
}
115+
97116
@SuppressWarnings("removal")
98117
public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
99118
Map<String, String> parameters = new HashMap<>();

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

+17
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,23 @@ public void getParameterMap() {
115115
Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5"));
116116
}
117117

118+
@Test
119+
public void getInterfaceMap() {
120+
Map<String, Object> parameters = new HashMap<>();
121+
parameters.put("input", "value1");
122+
parameters.put("output", 2);
123+
parameters.put("key3", 2.1);
124+
parameters.put("key4", new int[]{10, 20});
125+
parameters.put("key5", new Object[]{1.01, "abc"});
126+
Map<String, String> interfaceMap = StringUtils.getInterfaceMap(parameters);
127+
Assert.assertEquals(2, interfaceMap.size());
128+
Assert.assertEquals("value1", interfaceMap.get("input"));
129+
Assert.assertEquals("2", interfaceMap.get("output"));
130+
Assert.assertNull(interfaceMap.get("key3"));
131+
Assert.assertNull(interfaceMap.get("key4"));
132+
Assert.assertNull(interfaceMap.get("key5"));
133+
}
134+
118135
@Test
119136
public void processTextDocs() {
120137
List<String> processedDocs = StringUtils.processTextDocs(Arrays.asList("abc \n\n123\"4", null, "[1.01,\"abc\"]"));

ml-algorithms/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies {
2828
exclude group: "org.jetbrains", module: "annotations"
2929
}
3030
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
31+
implementation group: 'com.github.erosb', name: 'json-sKema', version: '0.15.0'
3132
implementation group: 'org.reflections', name: 'reflections', version: '0.9.12'
3233
implementation group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.2.1'
3334
implementation group: 'org.tribuo', name: 'tribuo-regression-sgd', version: '4.2.1'

0 commit comments

Comments
 (0)