|
15 | 15 | import org.opensearch.core.xcontent.ToXContentObject;
|
16 | 16 | import org.opensearch.core.xcontent.XContentBuilder;
|
17 | 17 | import org.opensearch.core.xcontent.XContentParser;
|
| 18 | +import org.opensearch.ml.common.connector.AbstractConnector; |
18 | 19 | import org.opensearch.ml.common.connector.Connector;
|
19 | 20 | import org.opensearch.ml.common.model.Guardrails;
|
20 | 21 | import org.opensearch.ml.common.model.MLDeploySetting;
|
|
29 | 30 | import java.io.IOException;
|
30 | 31 | import java.time.Instant;
|
31 | 32 | import java.util.ArrayList;
|
| 33 | +import java.util.HashMap; |
32 | 34 | import java.util.List;
|
33 | 35 | import java.util.Locale;
|
| 36 | +import java.util.Map; |
34 | 37 |
|
35 | 38 | import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
|
36 | 39 | import static org.opensearch.ml.common.CommonValue.USER;
|
37 | 40 | import static org.opensearch.ml.common.connector.Connector.createConnector;
|
| 41 | +import static org.opensearch.ml.common.utils.StringUtils.filterInterfaceMap; |
38 | 42 |
|
39 | 43 | @Getter
|
40 | 44 | public class MLModel implements ToXContentObject {
|
@@ -89,6 +93,7 @@ public class MLModel implements ToXContentObject {
|
89 | 93 | public static final String CONNECTOR_FIELD = "connector";
|
90 | 94 | public static final String CONNECTOR_ID_FIELD = "connector_id";
|
91 | 95 | public static final String GUARDRAILS_FIELD = "guardrails";
|
| 96 | + public static final String MODEL_INTERFACE_FIELD = "model_interface"; |
92 | 97 |
|
93 | 98 | private String name;
|
94 | 99 | private String modelGroupId;
|
@@ -134,6 +139,8 @@ public class MLModel implements ToXContentObject {
|
134 | 139 | private String connectorId;
|
135 | 140 | private Guardrails guardrails;
|
136 | 141 |
|
| 142 | + private Map<String, String> modelInterface; |
| 143 | + |
137 | 144 | @Builder(toBuilder = true)
|
138 | 145 | public MLModel(String name,
|
139 | 146 | String modelGroupId,
|
@@ -166,7 +173,8 @@ public MLModel(String name,
|
166 | 173 | Boolean isHidden,
|
167 | 174 | Connector connector,
|
168 | 175 | String connectorId,
|
169 |
| - Guardrails guardrails) { |
| 176 | + Guardrails guardrails, |
| 177 | + Map<String, String> modelInterface) { |
170 | 178 | this.name = name;
|
171 | 179 | this.modelGroupId = modelGroupId;
|
172 | 180 | this.algorithm = algorithm;
|
@@ -200,6 +208,7 @@ public MLModel(String name,
|
200 | 208 | this.connector = connector;
|
201 | 209 | this.connectorId = connectorId;
|
202 | 210 | this.guardrails = guardrails;
|
| 211 | + this.modelInterface = modelInterface; |
203 | 212 | }
|
204 | 213 |
|
205 | 214 | public MLModel(StreamInput input) throws IOException {
|
@@ -261,6 +270,9 @@ public MLModel(StreamInput input) throws IOException {
|
261 | 270 | if (input.readBoolean()) {
|
262 | 271 | this.guardrails = new Guardrails(input);
|
263 | 272 | }
|
| 273 | + if (input.readBoolean()) { |
| 274 | + modelInterface = input.readMap(StreamInput::readString, StreamInput::readString); |
| 275 | + } |
264 | 276 | }
|
265 | 277 | }
|
266 | 278 |
|
@@ -338,6 +350,12 @@ public void writeTo(StreamOutput out) throws IOException {
|
338 | 350 | } else {
|
339 | 351 | out.writeBoolean(false);
|
340 | 352 | }
|
| 353 | + if (modelInterface != null) { |
| 354 | + out.writeBoolean(true); |
| 355 | + out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString); |
| 356 | + } else { |
| 357 | + out.writeBoolean(false); |
| 358 | + } |
341 | 359 | }
|
342 | 360 |
|
343 | 361 | @Override
|
@@ -442,6 +460,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
|
442 | 460 | if (guardrails != null) {
|
443 | 461 | builder.field(GUARDRAILS_FIELD, guardrails);
|
444 | 462 | }
|
| 463 | + if (modelInterface != null) { |
| 464 | + builder.field(MODEL_INTERFACE_FIELD, modelInterface); |
| 465 | + } |
445 | 466 | builder.endObject();
|
446 | 467 | return builder;
|
447 | 468 | }
|
@@ -486,6 +507,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
|
486 | 507 | Connector connector = null;
|
487 | 508 | String connectorId = null;
|
488 | 509 | Guardrails guardrails = null;
|
| 510 | + Map<String, String> modelInterface = new HashMap<>(); |
489 | 511 |
|
490 | 512 | ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
|
491 | 513 | while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
|
@@ -617,6 +639,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
|
617 | 639 | case GUARDRAILS_FIELD:
|
618 | 640 | guardrails = Guardrails.parse(parser);
|
619 | 641 | break;
|
| 642 | + case MODEL_INTERFACE_FIELD: |
| 643 | + modelInterface = filterInterfaceMap(parser.map()); |
| 644 | + break; |
620 | 645 | default:
|
621 | 646 | parser.skipChildren();
|
622 | 647 | break;
|
@@ -656,6 +681,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
|
656 | 681 | .connector(connector)
|
657 | 682 | .connectorId(connectorId)
|
658 | 683 | .guardrails(guardrails)
|
| 684 | + .modelInterface(modelInterface) |
659 | 685 | .build();
|
660 | 686 | }
|
661 | 687 |
|
|
0 commit comments