Skip to content

Commit cf7243c

Browse files
committed
stash thread context before running forward action
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent de59efc commit cf7243c

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java

+19-14
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.opensearch.cluster.node.DiscoveryNodes;
2323
import org.opensearch.cluster.service.ClusterService;
2424
import org.opensearch.common.inject.Inject;
25+
import org.opensearch.common.util.concurrent.ThreadContext;
2526
import org.opensearch.core.action.ActionListener;
2627
import org.opensearch.core.common.io.stream.StreamInput;
2728
import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -161,13 +162,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
161162
.build();
162163
MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput);
163164

164-
transportService
165-
.sendRequest(
166-
getNodeById(coordinatingNodeId),
167-
MLForwardAction.NAME,
168-
deployModelDoneMessage,
169-
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
170-
);
165+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
166+
transportService
167+
.sendRequest(
168+
getNodeById(coordinatingNodeId),
169+
MLForwardAction.NAME,
170+
deployModelDoneMessage,
171+
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
172+
);
173+
}
171174
}, e -> {
172175
MLForwardInput mlForwardInput = MLForwardInput
173176
.builder()
@@ -179,13 +182,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
179182
.build();
180183
MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput);
181184

182-
transportService
183-
.sendRequest(
184-
getNodeById(coordinatingNodeId),
185-
MLForwardAction.NAME,
186-
deployModelDoneMessage,
187-
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
188-
);
185+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
186+
transportService
187+
.sendRequest(
188+
getNodeById(coordinatingNodeId),
189+
MLForwardAction.NAME,
190+
deployModelDoneMessage,
191+
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
192+
);
193+
}
189194
})
190195
);
191196

0 commit comments

Comments
 (0)