Skip to content

Commit 26c7f07

Browse files
committed
ML Model Interface (opensearch-project#2357)
* ML Interface poc Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Style fix Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Adjust model interface minimal support version Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix '*' import Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix UT Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Address review concern Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Change json schema package to highest stars library Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * style fix Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Move model interface from connector to model index Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix compilation and UT Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Move schema validation to rest layer Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Remove unnecessary dependencies Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Initiate modelInterface to null object Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Change sout to log Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Remove debug info Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix minor styles Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix minor styles Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix minor styles Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix build Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix UTs Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix UT Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Validate whole output schema Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix doc Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Fix style Signed-off-by: Sicheng Song <sicheng.song@outlook.com> * Rebase Signed-off-by: Sicheng Song <sicheng.song@outlook.com> --------- Signed-off-by: Sicheng Song 
<sicheng.song@outlook.com>
1 parent 4edcd17 commit 26c7f07

25 files changed

+598
-209
lines changed

build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ buildscript {
2323
}
2424

2525
common_utils_version = System.getProperty("common_utils.version", opensearch_build)
26-
kotlin_version = System.getProperty("kotlin.version", "1.8.21")
26+
kotlin_version = System.getProperty("kotlin.version", "1.9.23")
2727
}
2828

2929
repositories {

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+92-92
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public class CommonValue {
5656
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
5757
public static final String ML_TASK_INDEX = ".plugins-ml-task";
5858
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
59-
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
59+
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
6060
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
6161
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
6262
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
@@ -82,59 +82,56 @@ public class CommonValue {
8282
+ " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
8383
+ " }\n"
8484
+ " }\n";
85-
public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" +
86-
" \"_meta\": {\n" +
87-
" \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n" +
88-
" },\n" +
89-
" \"properties\": {\n" +
90-
" \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n" +
91-
" \"type\": \"text\",\n" +
92-
" \"fields\": {\n" +
93-
" \"keyword\": {\n" +
94-
" \"type\": \"keyword\",\n" +
95-
" \"ignore_above\": 256\n" +
96-
" }\n" +
97-
" }\n" +
98-
" },\n" +
99-
" \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n" +
100-
" \"type\": \"text\"\n" +
101-
" },\n" +
102-
" \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n" +
103-
" \"type\": \"integer\"\n" +
104-
" },\n" +
105-
" \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n" +
106-
" \"type\": \"keyword\"\n" +
107-
" },\n" +
108-
" \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n" +
109-
" \"type\": \"text\",\n" +
110-
" \"fields\": {\n" +
111-
" \"keyword\": {\n" +
112-
" \"type\": \"keyword\",\n" +
113-
" \"ignore_above\": 256\n" +
114-
" }\n" +
115-
" }\n" +
116-
" },\n" +
117-
" \"" + MLModelGroup.ACCESS + "\": {\n" +
118-
" \"type\": \"keyword\"\n" +
119-
" },\n" +
120-
" \"" + MLModelGroup.OWNER + "\": {\n" +
121-
" \"type\": \"nested\",\n" +
122-
" \"properties\": {\n" +
123-
" \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n"
124-
+
125-
" \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n"
126-
+
127-
" \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
128-
" \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
129-
+
130-
" }\n" +
131-
" },\n" +
132-
" \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n" +
133-
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" +
134-
" \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n" +
135-
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" +
136-
" }\n" +
137-
"}";
85+
public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n"
86+
+ " \"_meta\": {\n"
87+
+ " \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n"
88+
+ " },\n"
89+
+ " \"properties\": {\n"
90+
+ " \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n"
91+
+ " \"type\": \"text\",\n"
92+
+ " \"fields\": {\n"
93+
+ " \"keyword\": {\n"
94+
+ " \"type\": \"keyword\",\n"
95+
+ " \"ignore_above\": 256\n"
96+
+ " }\n"
97+
+ " }\n"
98+
+ " },\n"
99+
+ " \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n"
100+
+ " \"type\": \"text\"\n"
101+
+ " },\n"
102+
+ " \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n"
103+
+ " \"type\": \"integer\"\n"
104+
+ " },\n"
105+
+ " \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n"
106+
+ " \"type\": \"keyword\"\n"
107+
+ " },\n"
108+
+ " \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n"
109+
+ " \"type\": \"text\",\n"
110+
+ " \"fields\": {\n"
111+
+ " \"keyword\": {\n"
112+
+ " \"type\": \"keyword\",\n"
113+
+ " \"ignore_above\": 256\n"
114+
+ " }\n"
115+
+ " }\n"
116+
+ " },\n"
117+
+ " \"" + MLModelGroup.ACCESS + "\": {\n"
118+
+ " \"type\": \"keyword\"\n"
119+
+ " },\n"
120+
+ " \"" + MLModelGroup.OWNER + "\": {\n"
121+
+ " \"type\": \"nested\",\n"
122+
+ " \"properties\": {\n"
123+
+ " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n"
124+
+ " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n"
125+
+ " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n"
126+
+ " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
127+
+ " }\n"
128+
+ " },\n"
129+
+ " \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n"
130+
+ " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
131+
+ " \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n"
132+
+ " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
133+
+ " }\n"
134+
+ "}";
138135

139136
public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n"
140137
+ " \""
@@ -265,45 +262,48 @@ public class CommonValue {
265262
+ MLModel.LAST_UNDEPLOYED_TIME_FIELD
266263
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
267264
+ " \""
265+
+ MLModel.INTERFACE_FIELD
266+
+ "\": {\"type\": \"flat_object\"},\n"
267+
+ " \""
268268
+ MLModel.GUARDRAILS_FIELD
269-
+ "\" : {\n" +
270-
" \"properties\": {\n" +
271-
" \"input_guardrail\": {\n" +
272-
" \"properties\": {\n" +
273-
" \"regex\": {\n" +
274-
" \"type\": \"text\"\n" +
275-
" },\n" +
276-
" \"stop_words\": {\n" +
277-
" \"properties\": {\n" +
278-
" \"index_name\": {\n" +
279-
" \"type\": \"text\"\n" +
280-
" },\n" +
281-
" \"source_fields\": {\n" +
282-
" \"type\": \"text\"\n" +
283-
" }\n" +
284-
" }\n" +
285-
" }\n" +
286-
" }\n" +
287-
" },\n" +
288-
" \"output_guardrail\": {\n" +
289-
" \"properties\": {\n" +
290-
" \"regex\": {\n" +
291-
" \"type\": \"text\"\n" +
292-
" },\n" +
293-
" \"stop_words\": {\n" +
294-
" \"properties\": {\n" +
295-
" \"index_name\": {\n" +
296-
" \"type\": \"text\"\n" +
297-
" },\n" +
298-
" \"source_fields\": {\n" +
299-
" \"type\": \"text\"\n" +
300-
" }\n" +
301-
" }\n" +
302-
" }\n" +
303-
" }\n" +
304-
" }\n" +
305-
" }\n" +
306-
" },\n"
269+
+ "\" : {\n"
270+
+ " \"properties\": {\n"
271+
+ " \"input_guardrail\": {\n"
272+
+ " \"properties\": {\n"
273+
+ " \"regex\": {\n"
274+
+ " \"type\": \"text\"\n"
275+
+ " },\n"
276+
+ " \"stop_words\": {\n"
277+
+ " \"properties\": {\n"
278+
+ " \"index_name\": {\n"
279+
+ " \"type\": \"text\"\n"
280+
+ " },\n"
281+
+ " \"source_fields\": {\n"
282+
+ " \"type\": \"text\"\n"
283+
+ " }\n"
284+
+ " }\n"
285+
+ " }\n"
286+
+ " }\n"
287+
+ " },\n"
288+
+ " \"output_guardrail\": {\n"
289+
+ " \"properties\": {\n"
290+
+ " \"regex\": {\n"
291+
+ " \"type\": \"text\"\n"
292+
+ " },\n"
293+
+ " \"stop_words\": {\n"
294+
+ " \"properties\": {\n"
295+
+ " \"index_name\": {\n"
296+
+ " \"type\": \"text\"\n"
297+
+ " },\n"
298+
+ " \"source_fields\": {\n"
299+
+ " \"type\": \"text\"\n"
300+
+ " }\n"
301+
+ " }\n"
302+
+ " }\n"
303+
+ " }\n"
304+
+ " }\n"
305+
+ " }\n"
306+
+ " },\n"
307307
+ " \""
308308
+ MLModel.CONNECTOR_FIELD
309309
+ "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n},"

common/src/main/java/org/opensearch/ml/common/MLModel.java

+60-1
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,18 @@
2929
import java.io.IOException;
3030
import java.time.Instant;
3131
import java.util.ArrayList;
32+
import java.util.Arrays;
33+
import java.util.HashMap;
34+
import java.util.HashSet;
3235
import java.util.List;
3336
import java.util.Locale;
37+
import java.util.Map;
38+
import java.util.Set;
3439

3540
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
3641
import static org.opensearch.ml.common.CommonValue.USER;
3742
import static org.opensearch.ml.common.connector.Connector.createConnector;
43+
import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap;
3844

3945
@Getter
4046
public class MLModel implements ToXContentObject {
@@ -89,6 +95,9 @@ public class MLModel implements ToXContentObject {
8995
public static final String CONNECTOR_FIELD = "connector";
9096
public static final String CONNECTOR_ID_FIELD = "connector_id";
9197
public static final String GUARDRAILS_FIELD = "guardrails";
98+
public static final String INTERFACE_FIELD = "interface";
99+
100+
public static final Set<String> allowedInterfaceFieldKeys = new HashSet<>(Arrays.asList("input", "output"));
92101

93102
private String name;
94103
private String modelGroupId;
@@ -134,6 +143,36 @@ public class MLModel implements ToXContentObject {
134143
private String connectorId;
135144
private Guardrails guardrails;
136145

146+
/**
147+
* Model interface is a map that contains the input and output fields of the model, with JSON schema as the value.
148+
* Sample model interface:
149+
* {
150+
* "interface": {
151+
* "input": {
152+
* "properties": {
153+
* "parameters": {
154+
* "properties": {
155+
* "messages": {
156+
* "type": "string",
157+
* "description": "This is a test description field"
158+
* }
159+
* }
160+
* }
161+
* }
162+
* },
163+
* "output": {
164+
* "properties": {
165+
* "inference_results": {
166+
* "type": "array",
167+
* "description": "This is a test description field"
168+
* }
169+
* }
170+
* }
171+
* }
172+
* }
173+
*/
174+
private Map<String, String> modelInterface;
175+
137176
@Builder(toBuilder = true)
138177
public MLModel(String name,
139178
String modelGroupId,
@@ -166,7 +205,8 @@ public MLModel(String name,
166205
Boolean isHidden,
167206
Connector connector,
168207
String connectorId,
169-
Guardrails guardrails) {
208+
Guardrails guardrails,
209+
Map<String, String> modelInterface) {
170210
this.name = name;
171211
this.modelGroupId = modelGroupId;
172212
this.algorithm = algorithm;
@@ -200,6 +240,7 @@ public MLModel(String name,
200240
this.connector = connector;
201241
this.connectorId = connectorId;
202242
this.guardrails = guardrails;
243+
this.modelInterface = modelInterface;
203244
}
204245

205246
public MLModel(StreamInput input) throws IOException {
@@ -261,6 +302,9 @@ public MLModel(StreamInput input) throws IOException {
261302
if (input.readBoolean()) {
262303
this.guardrails = new Guardrails(input);
263304
}
305+
if (input.readBoolean()) {
306+
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
307+
}
264308
}
265309
}
266310

@@ -338,6 +382,12 @@ public void writeTo(StreamOutput out) throws IOException {
338382
} else {
339383
out.writeBoolean(false);
340384
}
385+
if (modelInterface != null) {
386+
out.writeBoolean(true);
387+
out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString);
388+
} else {
389+
out.writeBoolean(false);
390+
}
341391
}
342392

343393
@Override
@@ -442,6 +492,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
442492
if (guardrails != null) {
443493
builder.field(GUARDRAILS_FIELD, guardrails);
444494
}
495+
if (modelInterface != null) {
496+
builder.field(INTERFACE_FIELD, modelInterface);
497+
}
445498
builder.endObject();
446499
return builder;
447500
}
@@ -486,6 +539,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
486539
Connector connector = null;
487540
String connectorId = null;
488541
Guardrails guardrails = null;
542+
Map<String, String> modelInterface = null;
489543

490544
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
491545
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -617,6 +671,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
617671
case GUARDRAILS_FIELD:
618672
guardrails = Guardrails.parse(parser);
619673
break;
674+
case INTERFACE_FIELD:
675+
modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys);
676+
break;
620677
default:
621678
parser.skipChildren();
622679
break;
@@ -656,11 +713,13 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
656713
.connector(connector)
657714
.connectorId(connectorId)
658715
.guardrails(guardrails)
716+
.modelInterface(modelInterface)
659717
.build();
660718
}
661719

662720
public static MLModel fromStream(StreamInput in) throws IOException {
663721
MLModel mlModel = new MLModel(in);
664722
return mlModel;
665723
}
724+
666725
}

0 commit comments

Comments
 (0)