5
5
6
6
package org .opensearch .ml .action .prediction ;
7
7
8
+ import org .opensearch .OpenSearchStatusException ;
8
9
import org .opensearch .action .ActionListener ;
9
10
import org .opensearch .action .ActionRequest ;
10
11
import org .opensearch .action .support .ActionFilters ;
11
12
import org .opensearch .action .support .HandledTransportAction ;
12
13
import org .opensearch .client .Client ;
13
- import org .opensearch .cluster . service . ClusterService ;
14
+ import org .opensearch .common . breaker . CircuitBreakingException ;
14
15
import org .opensearch .common .inject .Inject ;
15
16
import org .opensearch .common .util .concurrent .ThreadContext ;
16
17
import org .opensearch .commons .authuser .User ;
17
- import org .opensearch .core .xcontent .NamedXContentRegistry ;
18
18
import org .opensearch .ml .common .FunctionName ;
19
19
import org .opensearch .ml .common .MLModel ;
20
+ import org .opensearch .ml .common .exception .MLResourceNotFoundException ;
20
21
import org .opensearch .ml .common .exception .MLValidationException ;
21
22
import org .opensearch .ml .common .transport .MLTaskResponse ;
22
23
import org .opensearch .ml .common .transport .prediction .MLPredictionTaskAction ;
27
28
import org .opensearch .ml .task .MLPredictTaskRunner ;
28
29
import org .opensearch .ml .task .MLTaskRunner ;
29
30
import org .opensearch .ml .utils .RestActionUtils ;
31
+ import org .opensearch .rest .RestStatus ;
30
32
import org .opensearch .tasks .Task ;
31
33
import org .opensearch .transport .TransportService ;
32
34
@@ -43,10 +45,6 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action
43
45
44
46
Client client ;
45
47
46
- ClusterService clusterService ;
47
-
48
- NamedXContentRegistry xContentRegistry ;
49
-
50
48
MLModelManager mlModelManager ;
51
49
52
50
ModelAccessControlHelper modelAccessControlHelper ;
@@ -57,19 +55,15 @@ public TransportPredictionTaskAction(
57
55
ActionFilters actionFilters ,
58
56
MLPredictTaskRunner mlPredictTaskRunner ,
59
57
MLModelCacheHelper modelCacheHelper ,
60
- ClusterService clusterService ,
61
58
Client client ,
62
- NamedXContentRegistry xContentRegistry ,
63
59
MLModelManager mlModelManager ,
64
60
ModelAccessControlHelper modelAccessControlHelper
65
61
) {
66
62
super (MLPredictionTaskAction .NAME , transportService , actionFilters , MLPredictionTaskRequest ::new );
67
63
this .mlPredictTaskRunner = mlPredictTaskRunner ;
68
64
this .transportService = transportService ;
69
65
this .modelCacheHelper = modelCacheHelper ;
70
- this .clusterService = clusterService ;
71
66
this .client = client ;
72
- this .xContentRegistry = xContentRegistry ;
73
67
this .mlModelManager = mlModelManager ;
74
68
this .modelAccessControlHelper = modelAccessControlHelper ;
75
69
}
@@ -108,7 +102,27 @@ public void onResponse(MLModel mlModel) {
108
102
}
109
103
}, e -> {
110
104
log .error ("Failed to Validate Access for ModelId " + modelId , e );
111
- wrappedListener .onFailure (e );
105
+ if (e instanceof OpenSearchStatusException ) {
106
+ wrappedListener
107
+ .onFailure (
108
+ new OpenSearchStatusException (
109
+ e .getMessage (),
110
+ RestStatus .fromCode (((OpenSearchStatusException ) e ).status ().getStatus ())
111
+ )
112
+ );
113
+ } else if (e instanceof MLResourceNotFoundException ) {
114
+ wrappedListener .onFailure (new OpenSearchStatusException (e .getMessage (), RestStatus .NOT_FOUND ));
115
+ } else if (e instanceof CircuitBreakingException ) {
116
+ wrappedListener .onFailure (e );
117
+ } else {
118
+ wrappedListener
119
+ .onFailure (
120
+ new OpenSearchStatusException (
121
+ "Failed to Validate Access for ModelId " + modelId ,
122
+ RestStatus .FORBIDDEN
123
+ )
124
+ );
125
+ }
112
126
}));
113
127
}
114
128
0 commit comments