8
8
import static org .opensearch .ml .common .MLModel .MODEL_CONTENT_FIELD ;
9
9
import static org .opensearch .ml .common .MLModel .OLD_MODEL_CONTENT_FIELD ;
10
10
11
+ import java .security .AccessController ;
12
+ import java .security .PrivilegedActionException ;
13
+ import java .security .PrivilegedExceptionAction ;
11
14
import java .util .ArrayList ;
12
15
import java .util .Collections ;
13
16
import java .util .HashSet ;
30
33
import org .opensearch .cluster .node .DiscoveryNode ;
31
34
import org .opensearch .cluster .service .ClusterService ;
32
35
import org .opensearch .common .Nullable ;
33
- import org .opensearch .common .util .concurrent .ThreadContext ;
34
36
import org .opensearch .commons .ConfigConstants ;
35
37
import org .opensearch .commons .authuser .User ;
36
38
import org .opensearch .core .action .ActionListener ;
44
46
import org .opensearch .search .fetch .subphase .FetchSourceContext ;
45
47
import org .opensearch .search .internal .InternalSearchResponse ;
46
48
49
+ import com .fasterxml .jackson .databind .JsonNode ;
50
+ import com .fasterxml .jackson .databind .ObjectMapper ;
47
51
import com .google .common .annotations .VisibleForTesting ;
48
52
49
53
import lombok .extern .log4j .Log4j2 ;
@@ -71,9 +75,12 @@ public class RestActionUtils {
71
75
public static final String PARAMETER_TOOL_NAME = "tool_name" ;
72
76
73
77
public static final String OPENDISTRO_SECURITY_CONFIG_PREFIX = "_opendistro_security_" ;
74
- public static final String OPENDISTRO_SECURITY_SSL_PRINCIPAL = OPENDISTRO_SECURITY_CONFIG_PREFIX + "ssl_principal" ;
78
+
79
+ public static final String OPENDISTRO_SECURITY_USER = OPENDISTRO_SECURITY_CONFIG_PREFIX + "user" ;
75
80
76
81
static final Set <LdapName > adminDn = new HashSet <>();
82
+ static final Set <String > adminUsernames = new HashSet <String >();
83
+ static final ObjectMapper objectMapper = new ObjectMapper ();
77
84
78
85
public static String getAlgorithm (RestRequest request ) {
79
86
String algorithm = request .param (PARAMETER_ALGORITHM );
@@ -212,7 +219,7 @@ public static Optional<String> getStringParam(RestRequest request, String paramN
212
219
*/
213
220
public static User getUserContext (Client client ) {
214
221
String userStr = client .threadPool ().getThreadContext ().getTransient (ConfigConstants .OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT );
215
- logger .debug ("Filtering result by " + userStr );
222
+ logger .debug ("Current user is " + userStr );
216
223
return User .parse (userStr );
217
224
}
218
225
@@ -226,13 +233,25 @@ public static boolean isSuperAdminUser(ClusterService clusterService, Client cli
226
233
logger .debug ("{} is registered as an admin dn" , dn );
227
234
adminDn .add (new LdapName (dn ));
228
235
} catch (final InvalidNameException e ) {
229
- logger .error ("Unable to parse admin dn {}" , dn , e );
236
+ logger .debug ("Unable to parse admin dn {}" , dn , e );
237
+ adminUsernames .add (dn );
230
238
}
231
239
}
232
240
233
- ThreadContext threadContext = client .threadPool ().getThreadContext ();
234
- final String sslPrincipal = threadContext .getTransient (OPENDISTRO_SECURITY_SSL_PRINCIPAL );
235
- return isAdminDN (sslPrincipal );
241
+ Object userObject = client .threadPool ().getThreadContext ().getTransient (OPENDISTRO_SECURITY_USER );
242
+ if (userObject == null )
243
+ return false ;
244
+ try {
245
+ return AccessController .doPrivileged ((PrivilegedExceptionAction <Boolean >) () -> {
246
+ String userContext = objectMapper .writeValueAsString (userObject );
247
+ final JsonNode node = objectMapper .readTree (userContext );
248
+ final String userName = node .get ("name" ).asText ();
249
+
250
+ return isAdminDN (userName );
251
+ });
252
+ } catch (PrivilegedActionException e ) {
253
+ throw new RuntimeException (e );
254
+ }
236
255
}
237
256
238
257
private static boolean isAdminDN (String dn ) {
@@ -241,7 +260,7 @@ private static boolean isAdminDN(String dn) {
241
260
try {
242
261
return isAdminDN (new LdapName (dn ));
243
262
} catch (InvalidNameException e ) {
244
- return false ;
263
+ return adminUsernames . contains ( dn ) ;
245
264
}
246
265
}
247
266
0 commit comments