Skip to content

Commit 500fff9

Browse files
authored
Get model group API (opensearch-project#1670)
* Get model group API Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 7cc9399 commit 500fff9

15 files changed

+1118
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.model_group;
7+
8+
import org.opensearch.action.ActionType;
9+
10+
public class MLModelGroupGetAction extends ActionType<MLModelGroupGetResponse> {
11+
public static final MLModelGroupGetAction INSTANCE = new MLModelGroupGetAction();
12+
public static final String NAME = "cluster:admin/opensearch/ml/model_groups/get";
13+
14+
private MLModelGroupGetAction() { super(NAME, MLModelGroupGetResponse::new);}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.model_group;
7+
8+
import lombok.AccessLevel;
9+
import lombok.Builder;
10+
import lombok.Getter;
11+
import lombok.ToString;
12+
import lombok.experimental.FieldDefaults;
13+
import org.opensearch.action.ActionRequest;
14+
import org.opensearch.action.ActionRequestValidationException;
15+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
16+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
17+
import org.opensearch.core.common.io.stream.StreamInput;
18+
import org.opensearch.core.common.io.stream.StreamOutput;
19+
20+
import java.io.ByteArrayInputStream;
21+
import java.io.ByteArrayOutputStream;
22+
import java.io.IOException;
23+
import java.io.UncheckedIOException;
24+
25+
import static org.opensearch.action.ValidateActions.addValidationError;
26+
27+
@Getter
28+
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
29+
@ToString
30+
public class MLModelGroupGetRequest extends ActionRequest {
31+
32+
String modelGroupId;
33+
34+
@Builder
35+
public MLModelGroupGetRequest(String modelGroupId) {
36+
this.modelGroupId = modelGroupId;
37+
}
38+
39+
public MLModelGroupGetRequest(StreamInput in) throws IOException {
40+
super(in);
41+
this.modelGroupId = in.readString();
42+
}
43+
44+
@Override
45+
public void writeTo(StreamOutput out) throws IOException {
46+
super.writeTo(out);
47+
out.writeString(this.modelGroupId);
48+
}
49+
50+
@Override
51+
public ActionRequestValidationException validate() {
52+
ActionRequestValidationException exception = null;
53+
54+
if (this.modelGroupId == null) {
55+
exception = addValidationError("Model group id can't be null", exception);
56+
}
57+
58+
return exception;
59+
}
60+
61+
public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionRequest) {
62+
if (actionRequest instanceof MLModelGroupGetRequest) {
63+
return (MLModelGroupGetRequest)actionRequest;
64+
}
65+
66+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
67+
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
68+
actionRequest.writeTo(osso);
69+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
70+
return new MLModelGroupGetRequest(input);
71+
}
72+
} catch (IOException e) {
73+
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupGetRequest", e);
74+
}
75+
}
76+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.model_group;
7+
8+
import lombok.Builder;
9+
import lombok.Getter;
10+
import lombok.ToString;
11+
import org.opensearch.core.action.ActionResponse;
12+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
13+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
14+
import org.opensearch.core.common.io.stream.StreamInput;
15+
import org.opensearch.core.common.io.stream.StreamOutput;
16+
import org.opensearch.core.xcontent.ToXContentObject;
17+
import org.opensearch.core.xcontent.XContentBuilder;
18+
import org.opensearch.ml.common.MLModelGroup;
19+
20+
import java.io.ByteArrayInputStream;
21+
import java.io.ByteArrayOutputStream;
22+
import java.io.IOException;
23+
import java.io.UncheckedIOException;
24+
25+
@Getter
26+
@ToString
27+
public class MLModelGroupGetResponse extends ActionResponse implements ToXContentObject {
28+
29+
MLModelGroup mlModelGroup;
30+
31+
@Builder
32+
public MLModelGroupGetResponse(MLModelGroup mlModelGroup) {
33+
this.mlModelGroup = mlModelGroup;
34+
}
35+
36+
37+
public MLModelGroupGetResponse(StreamInput in) throws IOException {
38+
super(in);
39+
mlModelGroup = mlModelGroup.fromStream(in);
40+
}
41+
42+
@Override
43+
public void writeTo(StreamOutput out) throws IOException{
44+
mlModelGroup.writeTo(out);
45+
}
46+
47+
@Override
48+
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
49+
return mlModelGroup.toXContent(xContentBuilder, params);
50+
}
51+
52+
public static MLModelGroupGetResponse fromActionResponse(ActionResponse actionResponse) {
53+
if (actionResponse instanceof MLModelGroupGetResponse) {
54+
return (MLModelGroupGetResponse) actionResponse;
55+
}
56+
57+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
58+
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
59+
actionResponse.writeTo(osso);
60+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
61+
return new MLModelGroupGetResponse(input);
62+
}
63+
} catch (IOException e) {
64+
throw new UncheckedIOException("failed to parse ActionResponse into MLModelGroupGetResponse", e);
65+
}
66+
}
67+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.model_group;
7+
8+
import org.junit.Before;
9+
import org.junit.Test;
10+
import org.opensearch.action.ActionRequest;
11+
import org.opensearch.action.ActionRequestValidationException;
12+
import org.opensearch.common.io.stream.BytesStreamOutput;
13+
import org.opensearch.core.common.io.stream.StreamOutput;
14+
15+
import java.io.IOException;
16+
import java.io.UncheckedIOException;
17+
18+
import static org.junit.Assert.assertEquals;
19+
import static org.junit.Assert.assertNotSame;
20+
import static org.junit.Assert.assertNull;
21+
import static org.junit.Assert.assertSame;
22+
23+
public class MLModelGroupGetRequestTest {
24+
private String modelGroupId;
25+
26+
@Before
27+
public void setUp() {
28+
modelGroupId = "test_id";
29+
}
30+
31+
@Test
32+
public void writeTo_Success() throws IOException {
33+
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder()
34+
.modelGroupId(modelGroupId).build();
35+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
36+
mlModelGroupGetRequest.writeTo(bytesStreamOutput);
37+
MLModelGroupGetRequest parsedModel = new MLModelGroupGetRequest(bytesStreamOutput.bytes().streamInput());
38+
assertEquals(parsedModel.getModelGroupId(), modelGroupId);
39+
}
40+
41+
@Test
42+
public void validate_Exception_NullmodelGroupId() {
43+
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().build();
44+
45+
ActionRequestValidationException exception = mlModelGroupGetRequest.validate();
46+
assertEquals("Validation Failed: 1: Model group id can't be null;", exception.getMessage());
47+
}
48+
49+
@Test
50+
public void fromActionRequest_Success() {
51+
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder()
52+
.modelGroupId(modelGroupId).build();
53+
ActionRequest actionRequest = new ActionRequest() {
54+
@Override
55+
public ActionRequestValidationException validate() {
56+
return null;
57+
}
58+
59+
@Override
60+
public void writeTo(StreamOutput out) throws IOException {
61+
mlModelGroupGetRequest.writeTo(out);
62+
}
63+
};
64+
MLModelGroupGetRequest result = MLModelGroupGetRequest.fromActionRequest(actionRequest);
65+
assertNotSame(result, mlModelGroupGetRequest);
66+
assertEquals(result.getModelGroupId(), mlModelGroupGetRequest.getModelGroupId());
67+
}
68+
69+
@Test(expected = UncheckedIOException.class)
70+
public void fromActionRequest_IOException() {
71+
ActionRequest actionRequest = new ActionRequest() {
72+
@Override
73+
public ActionRequestValidationException validate() {
74+
return null;
75+
}
76+
77+
@Override
78+
public void writeTo(StreamOutput out) throws IOException {
79+
throw new IOException("test");
80+
}
81+
};
82+
MLModelGroupGetRequest.fromActionRequest(actionRequest);
83+
}
84+
85+
@Test
86+
public void validate_Success() {
87+
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build();
88+
ActionRequestValidationException actionRequestValidationException = mlModelGroupGetRequest.validate();
89+
assertNull(actionRequestValidationException);
90+
}
91+
92+
@Test
93+
public void fromActionRequestWithMLModelGroupGetRequest_Success() {
94+
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build();
95+
MLModelGroupGetRequest mlModelGroupGetRequestFromActionRequest = MLModelGroupGetRequest.fromActionRequest(mlModelGroupGetRequest);
96+
assertSame(mlModelGroupGetRequest, mlModelGroupGetRequestFromActionRequest);
97+
assertEquals(mlModelGroupGetRequest.getModelGroupId(), mlModelGroupGetRequestFromActionRequest.getModelGroupId());
98+
}
99+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.model_group;
7+
8+
import org.junit.Before;
9+
import org.junit.Test;
10+
import org.opensearch.common.io.stream.BytesStreamOutput;
11+
import org.opensearch.common.xcontent.XContentType;
12+
import org.opensearch.core.action.ActionResponse;
13+
import org.opensearch.core.common.io.stream.StreamOutput;
14+
import org.opensearch.core.xcontent.MediaTypeRegistry;
15+
import org.opensearch.core.xcontent.ToXContent;
16+
import org.opensearch.core.xcontent.XContentBuilder;
17+
import org.opensearch.ml.common.MLModelGroup;
18+
19+
import java.io.IOException;
20+
import java.io.UncheckedIOException;
21+
22+
import static org.junit.Assert.assertEquals;
23+
import static org.junit.Assert.assertNotEquals;
24+
import static org.junit.Assert.assertNotNull;
25+
import static org.junit.Assert.assertNotSame;
26+
import static org.junit.Assert.assertSame;
27+
28+
public class MLModelGroupGetResponseTest {
29+
30+
MLModelGroup mlModelGroup;
31+
32+
@Before
33+
public void setUp() {
34+
mlModelGroup = MLModelGroup.builder()
35+
.name("modelGroup1")
36+
.latestVersion(1)
37+
.description("This is an example model group")
38+
.access("public")
39+
.build();
40+
}
41+
42+
@Test
43+
public void writeTo_Success() throws IOException {
44+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
45+
MLModelGroupGetResponse response = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
46+
response.writeTo(bytesStreamOutput);
47+
MLModelGroupGetResponse parsedResponse = new MLModelGroupGetResponse(bytesStreamOutput.bytes().streamInput());
48+
assertNotEquals(response.mlModelGroup, parsedResponse.mlModelGroup);
49+
assertEquals(response.mlModelGroup.getName(), parsedResponse.mlModelGroup.getName());
50+
assertEquals(response.mlModelGroup.getDescription(), parsedResponse.mlModelGroup.getDescription());
51+
assertEquals(response.mlModelGroup.getLatestVersion(), parsedResponse.mlModelGroup.getLatestVersion());
52+
}
53+
54+
@Test
55+
public void toXContentTest() throws IOException {
56+
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
57+
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
58+
mlModelGroupGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
59+
assertNotNull(builder);
60+
String jsonStr = builder.toString();
61+
assertEquals("{\"name\":\"modelGroup1\"," +
62+
"\"latest_version\":1," +
63+
"\"description\":\"This is an example model group\"," +
64+
"\"access\":\"public\"}",
65+
jsonStr);
66+
}
67+
68+
@Test
69+
public void fromActionResponseWithMLModelGroupGetResponse_Success() {
70+
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
71+
MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(mlModelGroupGetResponse);
72+
assertSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse);
73+
assertEquals(mlModelGroupGetResponse.mlModelGroup, mlModelGroupGetResponseFromActionResponse.mlModelGroup);
74+
}
75+
76+
@Test
77+
public void fromActionResponse_Success() {
78+
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
79+
ActionResponse actionResponse = new ActionResponse() {
80+
@Override
81+
public void writeTo(StreamOutput out) throws IOException {
82+
mlModelGroupGetResponse.writeTo(out);
83+
}
84+
};
85+
MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(actionResponse);
86+
assertNotSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse);
87+
assertNotEquals(mlModelGroupGetResponse.mlModelGroup, mlModelGroupGetResponseFromActionResponse.mlModelGroup);
88+
}
89+
90+
@Test(expected = UncheckedIOException.class)
91+
public void fromActionResponse_IOException() {
92+
ActionResponse actionResponse = new ActionResponse() {
93+
@Override
94+
public void writeTo(StreamOutput out) throws IOException {
95+
throw new IOException();
96+
}
97+
};
98+
MLModelGroupGetResponse.fromActionResponse(actionResponse);
99+
}
100+
}

0 commit comments

Comments
 (0)