3131package com .amazon .opendistroforelasticsearch .security .transport ;
3232
3333import java .net .InetSocketAddress ;
34+ import java .lang .reflect .Method ;
3435import java .security .cert .X509Certificate ;
3536import java .util .Objects ;
3637import java .util .UUID ;
@@ -93,21 +94,21 @@ public class OpenDistroSecurityRequestHandler<T extends TransportRequest> extend
9394 @ Override
9495 protected void messageReceivedDecorate (final T request , final TransportRequestHandler <T > handler ,
9596 final TransportChannel transportChannel , Task task ) throws Exception {
96-
97+
9798 String resolvedActionClass = request .getClass ().getSimpleName ();
98-
99+
99100 if (request instanceof BulkShardRequest ) {
100101 if (((BulkShardRequest ) request ).items ().length == 1 ) {
101102 resolvedActionClass = ((BulkShardRequest ) request ).items ()[0 ].request ().getClass ().getSimpleName ();
102103 }
103104 }
104-
105+
105106 if (request instanceof ConcreteShardRequest ) {
106107 resolvedActionClass = ((ConcreteShardRequest ) request ).getRequest ().getClass ().getSimpleName ();
107108 }
108-
109+
109110 String initialActionClassValue = getThreadContext ().getHeader (ConfigConstants .OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER );
110-
111+
111112 final ThreadContext .StoredContext sgContext = getThreadContext ().newStoredContext (false );
112113
113114 final String originHeader = getThreadContext ().getHeader (ConfigConstants .OPENDISTRO_SECURITY_ORIGIN_HEADER );
@@ -118,20 +119,29 @@ protected void messageReceivedDecorate(final T request, final TransportRequestHa
118119
119120 try {
120121
121- if (transportChannel .getChannelType () == null ) {
122- throw new RuntimeException ("Can not determine channel type (null)" );
123- }
122+ if (transportChannel .getChannelType () == null ) {
123+ throw new RuntimeException ("Can not determine channel type (null)" );
124+ }
125+
126+ String channelType = transportChannel .getChannelType ();
127+
128+ if (!channelType .equals ("direct" ) && !channelType .equals ("netty" )) {
129+ Class wrappedChannelCls = transportChannel .getClass ();
124130
125- if (!transportChannel .getChannelType ().equals ("direct" ) && !transportChannel .getChannelType ().equals ("netty" )
126- && !transportChannel .getChannelType ().equals ("PerformanceAnalyzerTransportChannelType" )) {
127- throw new RuntimeException ("Unknown channel type " +transportChannel .getChannelType ());
128- }
131+ try {
132+ Method getInnerChannel = wrappedChannelCls .getMethod ("getInnerChannel" , null );
133+ TransportChannel innerChannel = (TransportChannel )(getInnerChannel .invoke (transportChannel ));
134+ channelType = innerChannel .getChannelType ();
135+ } catch (NoSuchMethodException ex ) {
136+ throw new RuntimeException ("Unknown channel type " + channelType + " does not implement getInnerChannel method." );
137+ }
138+ }
129139
130- getThreadContext ().putTransient (ConfigConstants .OPENDISTRO_SECURITY_CHANNEL_TYPE , transportChannel . getChannelType () );
131- getThreadContext ().putTransient (ConfigConstants .OPENDISTRO_SECURITY_ACTION_NAME , task .getAction ());
140+ getThreadContext ().putTransient (ConfigConstants .OPENDISTRO_SECURITY_CHANNEL_TYPE , channelType );
141+ getThreadContext ().putTransient (ConfigConstants .OPENDISTRO_SECURITY_ACTION_NAME , task .getAction ());
132142
133143 //bypass non-netty requests
134- if (transportChannel . getChannelType () .equals ("direct" )) {
144+ if (channelType .equals ("direct" )) {
135145 final String userHeader = getThreadContext ().getHeader (ConfigConstants .OPENDISTRO_SECURITY_USER_HEADER );
136146
137147 if (!Strings .isNullOrEmpty (userHeader )) {
@@ -147,7 +157,7 @@ protected void messageReceivedDecorate(final T request, final TransportRequestHa
147157 if (actionTrace .isTraceEnabled ()) {
148158 getThreadContext ().putHeader ("_opendistro_security_trace" +System .currentTimeMillis ()+"#" +UUID .randomUUID ().toString (), Thread .currentThread ().getName ()+" DIR -> " +transportChannel .getChannelType ()+" " +getThreadContext ().getHeaders ());
149159 }
150-
160+
151161 putInitialActionClassHeader (initialActionClassValue , resolvedActionClass );
152162
153163 super .messageReceivedDecorate (request , handler , transportChannel , task );
@@ -165,9 +175,9 @@ protected void messageReceivedDecorate(final T request, final TransportRequestHa
165175 auditLog .logMissingPrivileges (task .getAction (), request , task );
166176 log .error ("Internal or shard requests (" +task .getAction ()+") not allowed from a non-server node for transport type " +transportChannel .getChannelType ());
167177 transportChannel .sendResponse (new ElasticsearchSecurityException (
168- "Internal or shard requests not allowed from a non-server node for transport type " +transportChannel .getChannelType ()));
178+ "Internal or shard requests not allowed from a non-server node for transport type " +transportChannel .getChannelType ()));
169179 return ;
170- }
180+ }
171181
172182
173183 String principal = null ;
@@ -223,34 +233,34 @@ protected void messageReceivedDecorate(final T request, final TransportRequestHa
223233
224234 User user ;
225235 //try {
226- if ((user = backendRegistry .authenticate (request , principal , task , task .getAction ())) == null ) {
227- org .apache .logging .log4j .ThreadContext .remove ("user" );
228-
229- if (task .getAction ().equals (WhoAmIAction .NAME )) {
230- super .messageReceivedDecorate (request , handler , transportChannel , task );
231- return ;
232- }
233-
234- if (task .getAction ().equals ("cluster:monitor/nodes/liveness" )
235- || task .getAction ().equals ("internal:transport/handshake" )) {
236- super .messageReceivedDecorate (request , handler , transportChannel , task );
237- return ;
238- }
239-
240-
241- log .error ("Cannot authenticate {} for {}" , getThreadContext ().getTransient (ConfigConstants .OPENDISTRO_SECURITY_USER ), task .getAction ());
242- transportChannel .sendResponse (new ElasticsearchSecurityException ("Cannot authenticate " +getThreadContext ().getTransient (ConfigConstants .OPENDISTRO_SECURITY_USER )));
236+ if ((user = backendRegistry .authenticate (request , principal , task , task .getAction ())) == null ) {
237+ org .apache .logging .log4j .ThreadContext .remove ("user" );
238+
239+ if (task .getAction ().equals (WhoAmIAction .NAME )) {
240+ super .messageReceivedDecorate (request , handler , transportChannel , task );
243241 return ;
244- } else {
245- // make it possible to filter logs by username
246- org .apache .logging .log4j .ThreadContext .put ("user" , user .getName ());
247242 }
243+
244+ if (task .getAction ().equals ("cluster:monitor/nodes/liveness" )
245+ || task .getAction ().equals ("internal:transport/handshake" )) {
246+ super .messageReceivedDecorate (request , handler , transportChannel , task );
247+ return ;
248+ }
249+
250+
251+ log .error ("Cannot authenticate {} for {}" , getThreadContext ().getTransient (ConfigConstants .OPENDISTRO_SECURITY_USER ), task .getAction ());
252+ transportChannel .sendResponse (new ElasticsearchSecurityException ("Cannot authenticate " +getThreadContext ().getTransient (ConfigConstants .OPENDISTRO_SECURITY_USER )));
253+ return ;
254+ } else {
255+ // make it possible to filter logs by username
256+ org .apache .logging .log4j .ThreadContext .put ("user" , user .getName ());
257+ }
248258 //} catch (Exception e) {
249- // log.error("Error authentication transport user "+e, e);
250- //auditLog.logFailedLogin(principal, false, null, request);
251- //transportChannel.sendResponse(ExceptionsHelper.convertToElastic(e));
252- //return;
253- //}
259+ // log.error("Error authentication transport user "+e, e);
260+ //auditLog.logFailedLogin(principal, false, null, request);
261+ //transportChannel.sendResponse(ExceptionsHelper.convertToElastic(e));
262+ //return;
263+ //}
254264
255265 getThreadContext ().putTransient (ConfigConstants .OPENDISTRO_SECURITY_USER , user );
256266 TransportAddress originalRemoteAddress = request .remoteAddress ();
@@ -268,9 +278,9 @@ protected void messageReceivedDecorate(final T request, final TransportRequestHa
268278 getThreadContext ().putHeader ("_opendistro_security_trace" +System .currentTimeMillis ()+"#" +UUID .randomUUID ().toString (), Thread .currentThread ().getName ()+" NETTI -> " +transportChannel .getChannelType ()+" " +getThreadContext ().getHeaders ().entrySet ().stream ().filter (p ->!p .getKey ().startsWith ("_opendistro_security_trace" )).collect (Collectors .toMap (p -> p .getKey (), p -> p .getValue ())));
269279 }
270280
271-
281+
272282 putInitialActionClassHeader (initialActionClassValue , resolvedActionClass );
273-
283+
274284 super .messageReceivedDecorate (request , handler , transportChannel , task );
275285 }
276286 } finally {
0 commit comments