Skip to content

Commit 69d1275

Browse files
committed
Move model interface from connector to model
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
1 parent 2b6c969 commit 69d1275

File tree

18 files changed

+275
-157
lines changed

18 files changed

+275
-157
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

+42-43
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public class CommonValue {
5656
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
5757
public static final String ML_TASK_INDEX = ".plugins-ml-task";
5858
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
59-
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
59+
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
6060
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
6161
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
6262
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
@@ -160,11 +160,7 @@ public class CommonValue {
160160
+ "\" : {\"type\": \"flat_object\"},\n"
161161
+ " \""
162162
+ AbstractConnector.ACTIONS_FIELD
163-
+ "\" : {\"type\": \"flat_object\"},\n"
164-
+ " \""
165-
+ AbstractConnector.MODEL_INTERFACE_FIELD
166-
+ "\" : {\"type\": \"flat_object\"}\n";
167-
163+
+ "\" : {\"type\": \"flat_object\"},\n";
168164

169165
public static final String ML_MODEL_INDEX_MAPPING = "{\n"
170166
+ " \"_meta\": {\"schema_version\": "
@@ -269,45 +265,48 @@ public class CommonValue {
269265
+ MLModel.LAST_UNDEPLOYED_TIME_FIELD
270266
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
271267
+ " \""
268+
+ MLModel.MODEL_INTERFACE_FIELD
269+
+ "\" : {\"type\": \"flat_object\"},\n"
270+
+ " \""
272271
+ MLModel.GUARDRAILS_FIELD
273272
+ "\" : {\n" +
274-
" \"properties\": {\n" +
275-
" \"input_guardrail\": {\n" +
276-
" \"properties\": {\n" +
277-
" \"regex\": {\n" +
278-
" \"type\": \"text\"\n" +
279-
" },\n" +
280-
" \"stop_words\": {\n" +
281-
" \"properties\": {\n" +
282-
" \"index_name\": {\n" +
283-
" \"type\": \"text\"\n" +
284-
" },\n" +
285-
" \"source_fields\": {\n" +
286-
" \"type\": \"text\"\n" +
287-
" }\n" +
288-
" }\n" +
289-
" }\n" +
290-
" }\n" +
291-
" },\n" +
292-
" \"output_guardrail\": {\n" +
293-
" \"properties\": {\n" +
294-
" \"regex\": {\n" +
295-
" \"type\": \"text\"\n" +
296-
" },\n" +
297-
" \"stop_words\": {\n" +
298-
" \"properties\": {\n" +
299-
" \"index_name\": {\n" +
300-
" \"type\": \"text\"\n" +
301-
" },\n" +
302-
" \"source_fields\": {\n" +
303-
" \"type\": \"text\"\n" +
304-
" }\n" +
305-
" }\n" +
306-
" }\n" +
307-
" }\n" +
308-
" }\n" +
309-
" }\n" +
310-
" },\n"
273+
" \"properties\": {\n" +
274+
" \"input_guardrail\": {\n" +
275+
" \"properties\": {\n" +
276+
" \"regex\": {\n" +
277+
" \"type\": \"text\"\n" +
278+
" },\n" +
279+
" \"stop_words\": {\n" +
280+
" \"properties\": {\n" +
281+
" \"index_name\": {\n" +
282+
" \"type\": \"text\"\n" +
283+
" },\n" +
284+
" \"source_fields\": {\n" +
285+
" \"type\": \"text\"\n" +
286+
" }\n" +
287+
" }\n" +
288+
" }\n" +
289+
" }\n" +
290+
" },\n" +
291+
" \"output_guardrail\": {\n" +
292+
" \"properties\": {\n" +
293+
" \"regex\": {\n" +
294+
" \"type\": \"text\"\n" +
295+
" },\n" +
296+
" \"stop_words\": {\n" +
297+
" \"properties\": {\n" +
298+
" \"index_name\": {\n" +
299+
" \"type\": \"text\"\n" +
300+
" },\n" +
301+
" \"source_fields\": {\n" +
302+
" \"type\": \"text\"\n" +
303+
" }\n" +
304+
" }\n" +
305+
" }\n" +
306+
" }\n" +
307+
" }\n" +
308+
" }\n" +
309+
" },\n"
311310
+ " \""
312311
+ MLModel.CONNECTOR_FIELD
313312
+ "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n},"

common/src/main/java/org/opensearch/ml/common/MLModel.java

+27-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.core.xcontent.ToXContentObject;
1616
import org.opensearch.core.xcontent.XContentBuilder;
1717
import org.opensearch.core.xcontent.XContentParser;
18+
import org.opensearch.ml.common.connector.AbstractConnector;
1819
import org.opensearch.ml.common.connector.Connector;
1920
import org.opensearch.ml.common.model.Guardrails;
2021
import org.opensearch.ml.common.model.MLDeploySetting;
@@ -29,12 +30,15 @@
2930
import java.io.IOException;
3031
import java.time.Instant;
3132
import java.util.ArrayList;
33+
import java.util.HashMap;
3234
import java.util.List;
3335
import java.util.Locale;
36+
import java.util.Map;
3437

3538
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
3639
import static org.opensearch.ml.common.CommonValue.USER;
3740
import static org.opensearch.ml.common.connector.Connector.createConnector;
41+
import static org.opensearch.ml.common.utils.StringUtils.filterInterfaceMap;
3842

3943
@Getter
4044
public class MLModel implements ToXContentObject {
@@ -89,6 +93,7 @@ public class MLModel implements ToXContentObject {
8993
public static final String CONNECTOR_FIELD = "connector";
9094
public static final String CONNECTOR_ID_FIELD = "connector_id";
9195
public static final String GUARDRAILS_FIELD = "guardrails";
96+
public static final String MODEL_INTERFACE_FIELD = "model_interface";
9297

9398
private String name;
9499
private String modelGroupId;
@@ -134,6 +139,8 @@ public class MLModel implements ToXContentObject {
134139
private String connectorId;
135140
private Guardrails guardrails;
136141

142+
private Map<String, String> modelInterface;
143+
137144
@Builder(toBuilder = true)
138145
public MLModel(String name,
139146
String modelGroupId,
@@ -166,7 +173,8 @@ public MLModel(String name,
166173
Boolean isHidden,
167174
Connector connector,
168175
String connectorId,
169-
Guardrails guardrails) {
176+
Guardrails guardrails,
177+
Map<String, String> modelInterface) {
170178
this.name = name;
171179
this.modelGroupId = modelGroupId;
172180
this.algorithm = algorithm;
@@ -200,6 +208,7 @@ public MLModel(String name,
200208
this.connector = connector;
201209
this.connectorId = connectorId;
202210
this.guardrails = guardrails;
211+
this.modelInterface = modelInterface;
203212
}
204213

205214
public MLModel(StreamInput input) throws IOException {
@@ -261,6 +270,9 @@ public MLModel(StreamInput input) throws IOException {
261270
if (input.readBoolean()) {
262271
this.guardrails = new Guardrails(input);
263272
}
273+
if (input.readBoolean()) {
274+
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
275+
}
264276
}
265277
}
266278

@@ -338,6 +350,12 @@ public void writeTo(StreamOutput out) throws IOException {
338350
} else {
339351
out.writeBoolean(false);
340352
}
353+
if (modelInterface != null) {
354+
out.writeBoolean(true);
355+
out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString);
356+
} else {
357+
out.writeBoolean(false);
358+
}
341359
}
342360

343361
@Override
@@ -442,6 +460,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
442460
if (guardrails != null) {
443461
builder.field(GUARDRAILS_FIELD, guardrails);
444462
}
463+
if (modelInterface != null) {
464+
builder.field(MODEL_INTERFACE_FIELD, modelInterface);
465+
}
445466
builder.endObject();
446467
return builder;
447468
}
@@ -486,6 +507,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
486507
Connector connector = null;
487508
String connectorId = null;
488509
Guardrails guardrails = null;
510+
Map<String, String> modelInterface = new HashMap<>();
489511

490512
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
491513
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -617,6 +639,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
617639
case GUARDRAILS_FIELD:
618640
guardrails = Guardrails.parse(parser);
619641
break;
642+
case MODEL_INTERFACE_FIELD:
643+
modelInterface = filterInterfaceMap(parser.map());
644+
break;
620645
default:
621646
parser.skipChildren();
622647
break;
@@ -656,6 +681,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
656681
.connector(connector)
657682
.connectorId(connectorId)
658683
.guardrails(guardrails)
684+
.modelInterface(modelInterface)
659685
.build();
660686
}
661687

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ public abstract class AbstractConnector implements Connector {
4444
public static final String OWNER_FIELD = "owner";
4545
public static final String ACCESS_FIELD = "access";
4646
public static final String CLIENT_CONFIG_FIELD = "client_config";
47-
public static final String MODEL_INTERFACE_FIELD = "model_interface";
4847

4948

5049
protected String name;
@@ -70,7 +69,6 @@ public abstract class AbstractConnector implements Connector {
7069
protected Instant lastUpdateTime;
7170
@Setter
7271
protected ConnectorClientConfig connectorClientConfig;
73-
protected Map<String, String> modelInterface;
7472

7573
protected Map<String, String> createPredictDecryptedHeaders(Map<String, String> headers) {
7674
if (headers == null) {

common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ public class AwsConnector extends HttpConnector {
3232
public AwsConnector(String name, String description, String version, String protocol,
3333
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
3434
List<String> backendRoles, AccessMode accessMode, User owner,
35-
ConnectorClientConfig connectorClientConfig, Map<String, String> modelInterface) {
35+
ConnectorClientConfig connectorClientConfig) {
3636
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode,
37-
owner, connectorClientConfig, modelInterface);
37+
owner, connectorClientConfig);
3838
validate();
3939
}
4040

common/src/main/java/org/opensearch/ml/common/connector/Connector.java

-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ public interface Connector extends ToXContentObject, Writeable {
6060

6161
String getPredictHttpMethod();
6262

63-
Map<String, String> getModelInterface();
64-
6563
<T> T createPredictPayload(Map<String, String> parameters);
6664

6765
void decrypt(Function<String, String> function);

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+2-22
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
3535
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
3636
import static org.opensearch.ml.common.utils.StringUtils.isJson;
37-
import static org.opensearch.ml.common.utils.StringUtils.filterInterfaceMap;
3837
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
3938

4039
@Log4j2
@@ -54,7 +53,7 @@ public class HttpConnector extends AbstractConnector {
5453
public HttpConnector(String name, String description, String version, String protocol,
5554
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
5655
List<String> backendRoles, AccessMode accessMode, User owner,
57-
ConnectorClientConfig connectorClientConfig, Map<String, String> modelInterface) {
56+
ConnectorClientConfig connectorClientConfig) {
5857
validateProtocol(protocol);
5958
this.name = name;
6059
this.description = description;
@@ -67,7 +66,6 @@ public HttpConnector(String name, String description, String version, String pro
6766
this.access = accessMode;
6867
this.owner = owner;
6968
this.connectorClientConfig = connectorClientConfig;
70-
this.modelInterface = modelInterface;
7169

7270
}
7371

@@ -129,9 +127,6 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
129127
case CLIENT_CONFIG_FIELD:
130128
connectorClientConfig = ConnectorClientConfig.parse(parser);
131129
break;
132-
case MODEL_INTERFACE_FIELD:
133-
modelInterface = filterInterfaceMap(parser.map());
134-
break;
135130
default:
136131
parser.skipChildren();
137132
break;
@@ -181,9 +176,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
181176
if (connectorClientConfig != null) {
182177
builder.field(CLIENT_CONFIG_FIELD, connectorClientConfig);
183178
}
184-
if (modelInterface != null) {
185-
builder.field(MODEL_INTERFACE_FIELD, modelInterface);
186-
}
187179
builder.endObject();
188180
return builder;
189181
}
@@ -227,9 +219,6 @@ private void parseFromStream(StreamInput input) throws IOException {
227219
if (input.readBoolean()) {
228220
this.connectorClientConfig = new ConnectorClientConfig(input);
229221
}
230-
if (input.readBoolean()) {
231-
this.modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
232-
}
233222
}
234223

235224
@Override
@@ -280,12 +269,6 @@ public void writeTo(StreamOutput out) throws IOException {
280269
} else {
281270
out.writeBoolean(false);
282271
}
283-
if (modelInterface != null) {
284-
out.writeBoolean(true);
285-
out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString);
286-
} else {
287-
out.writeBoolean(false);
288-
}
289272
}
290273

291274
@Override
@@ -321,9 +304,6 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
321304
if (updateContent.getConnectorClientConfig() != null) {
322305
this.connectorClientConfig = updateContent.getConnectorClientConfig();
323306
}
324-
if (updateContent.getModelInterface() != null && updateContent.getModelInterface().size() > 0) {
325-
this.modelInterface = updateContent.getModelInterface();
326-
}
327307
}
328308

329309
@Override
@@ -381,7 +361,7 @@ public void decrypt(Function<String, String> function) {
381361

382362
@Override
383363
public Connector cloneConnector() {
384-
try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){
364+
try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) {
385365
this.writeTo(bytesStreamOutput);
386366
StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
387367
return new HttpConnector(streamInput);

0 commit comments

Comments
 (0)