18
18
import java .util .Optional ;
19
19
20
20
import org .opensearch .client .node .NodeClient ;
21
+ import org .opensearch .common .util .concurrent .ThreadContext ;
21
22
import org .opensearch .core .action .ActionListener ;
22
23
import org .opensearch .core .rest .RestStatus ;
23
24
import org .opensearch .core .xcontent .XContentParser ;
24
25
import org .opensearch .ml .common .FunctionName ;
25
26
import org .opensearch .ml .common .MLModel ;
26
27
import org .opensearch .ml .common .input .MLInput ;
27
- import org .opensearch .ml .common .transport .model .MLModelGetAction ;
28
- import org .opensearch .ml .common .transport .model .MLModelGetRequest ;
29
- import org .opensearch .ml .common .transport .model .MLModelGetResponse ;
30
28
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskAction ;
31
29
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskRequest ;
32
30
import org .opensearch .ml .model .MLModelManager ;
@@ -91,9 +89,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
91
89
}
92
90
93
91
return channel -> {
94
- MLModelGetRequest getModelRequest = new MLModelGetRequest (modelId , false );
95
- ActionListener <MLModelGetResponse > listener = ActionListener .wrap (r -> {
96
- MLModel mlModel = r .getMlModel ();
92
+ ActionListener <MLModel > listener = ActionListener .wrap (mlModel -> {
97
93
String algoName = mlModel .getAlgorithm ().name ();
98
94
client
99
95
.execute (
@@ -109,8 +105,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
109
105
log .error ("Failed to send error response" , ex );
110
106
}
111
107
});
112
- client .execute (MLModelGetAction .INSTANCE , getModelRequest , listener );
113
-
108
+ try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
109
+ modelManager .getModel (modelId , ActionListener .runBefore (listener , () -> context .restore ()));
110
+ }
114
111
};
115
112
}
116
113
0 commit comments