org.brutusin.instrumentation.InstrumentatorTest.java Source code

Java tutorial

Introduction

Here is the source code for org.brutusin.instrumentation.InstrumentatorTest.java

Source

/*
 * Copyright 2014 brutusin.org
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.brutusin.instrumentation;

import static org.junit.Assert.*;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.apache.commons.io.IOUtils;
import org.brutusin.commons.Bean;
import org.junit.Test;

public class InstrumentatorTest {

    private static final String callbackId = "1";

    private static Class<?> clazz = SimpleClass.class;

    private static Class instrumentClass(Interceptor interceptor) throws Exception {
        if (Callback.getInstance(callbackId) != null) {
            Callback.removeCallback(callbackId);
        }
        Callback.registerCallback(callbackId, interceptor);
        ByteClassLoader cl = new ByteClassLoader();
        String className = clazz.getCanonicalName();
        String resourceName = className.replace('.', '/') + ".class";
        InputStream is = clazz.getClassLoader().getResourceAsStream(resourceName);
        byte[] bytes = IOUtils.toByteArray(is);
        Instrumentator instrumentator = new Instrumentator(className, bytes, interceptor, callbackId);
        byte[] newbytes = instrumentator.modifyClass();
        return cl.loadClass(className, newbytes);
    }

    public void compareStaticMethodResult(String methodName, Class[] argClasses, Object[] args) throws Exception {
        Method method = clazz.getMethod(methodName, argClasses);
        Object result = method.invoke(null, args);
        Class instrumentedClass = instrumentClass(new VoidInterceptor());
        Method method2 = instrumentedClass.getMethod(methodName, argClasses);
        Object result2 = method2.invoke(null, args);
        assertEquals(result2, result);
    }

    @Test
    public void testEqualResult() throws Exception {
        compareStaticMethodResult("sayHello", new Class[] { String.class }, new Object[] { "world" });
    }

    @Test
    public void testException() throws Exception {
        final Bean<String> executionIdBean = new Bean<String>();
        final Bean<Integer> counterBean = new Bean<Integer>();
        final Bean<Throwable> thBean = new Bean<Throwable>();
        final Bean<AssertionError> assertionBean = new Bean<AssertionError>();

        Interceptor interceptor = new VoidInterceptor() {
            @Override
            public void doOnStart(Object source, Object[] arg, String executionId) {
                try {
                    assertNull(executionIdBean.getValue());
                    assertNull(counterBean.getValue());
                    assertNull(thBean.getValue());
                    executionIdBean.setValue(executionId);
                    counterBean.setValue(1);
                } catch (AssertionError e) {
                    if (assertionBean.getValue() == null) {
                        assertionBean.setValue(e);
                    } else {
                        //assertionBean.getValue().addSuppressed(e);
                        assertionBean.setValue(e); // JDK 6 comp
                    }
                }
            }

            @Override
            public void doOnThrowableThrown(Object source, Throwable th, String executionId) {
                try {
                    assertNotNull(executionIdBean.getValue());
                    assertEquals(executionId, executionIdBean.getValue());
                    assertEquals(new Integer(1), counterBean.getValue());
                    assertNull(thBean.getValue());
                    counterBean.setValue(counterBean.getValue() + 1);
                    thBean.setValue(th);
                } catch (AssertionError e) {
                    if (assertionBean.getValue() == null) {
                        assertionBean.setValue(e);
                    } else {
                        //assertionBean.getValue().addSuppressed(e);
                        assertionBean.setValue(e); // JDK 6 comp
                    }
                }
            }

            @Override
            public void doOnThrowableUncatched(Object source, Throwable th, String executionId) {
                try {
                    assertNotNull(executionIdBean.getValue());
                    assertNotNull(thBean.getValue());
                    assertEquals(executionId, executionIdBean.getValue());
                    assertEquals(th, thBean.getValue());
                    assertEquals(new Integer(2), counterBean.getValue());
                } catch (AssertionError e) {
                    if (assertionBean.getValue() == null) {
                        assertionBean.setValue(e);
                    } else {
                        //assertionBean.getValue().addSuppressed(e);
                        assertionBean.setValue(e); // JDK 6 comp
                    }
                }
            }

            @Override
            public void doOnFinish(Object source, Object o, String executionId) {
                AssertionError e = new AssertionError();
                if (assertionBean.getValue() == null) {
                    assertionBean.setValue(e);
                } else {
                    //assertionBean.getValue().addSuppressed(e);
                    assertionBean.setValue(e); // JDK 6 comp
                }
            }
        };

        try {
            Class instrumentedClass = instrumentClass(interceptor);
            instrumentedClass.getMethod("throwHello").invoke(null);
        } catch (InvocationTargetException ite) {
            assertEquals(ite.getTargetException(), thBean.getValue());
        }
        if (assertionBean.getValue() != null) {
            throw assertionBean.getValue();
        }
        assertTrue(thBean.getValue() instanceof RuntimeException
                && thBean.getValue().getMessage().equals(SimpleClass.GREETING));
    }

    @Test
    public void testSuccess() throws Exception {
        final Bean<String> executionIdBean = new Bean<String>();
        final Bean<AssertionError> assertionBean = new Bean<AssertionError>();

        Interceptor interceptor = new VoidInterceptor() {

            @Override
            public void doOnStart(Object source, Object[] arg, String executionId) {
                try {
                    assertNull(executionIdBean.getValue());
                    executionIdBean.setValue(executionId);
                } catch (AssertionError e) {
                    if (assertionBean.getValue() == null) {
                        assertionBean.setValue(e);
                    } else {
                        //assertionBean.getValue().addSuppressed(e);
                        assertionBean.setValue(e); // JDK 6 comp
                    }
                }
            }

            @Override
            public void doOnFinish(Object source, Object o, String executionId) {
                try {
                    assertNotNull(executionIdBean.getValue());
                    assertEquals(executionId, executionIdBean.getValue());
                } catch (AssertionError e) {
                    if (assertionBean.getValue() == null) {
                        assertionBean.setValue(e);
                    } else {
                        //assertionBean.getValue().addSuppressed(e);
                        assertionBean.setValue(e); // JDK 6 comp
                    }
                }
            }

            @Override
            public void doOnThrowableThrown(Object source, Throwable th, String executionId) {
                AssertionError e = new AssertionError();
                if (assertionBean.getValue() == null) {
                    assertionBean.setValue(e);
                } else {
                    //assertionBean.getValue().addSuppressed(e);
                    assertionBean.setValue(e); // JDK 6 comp
                }
            }

            @Override
            public void doOnThrowableUncatched(Object source, Throwable th, String executionId) {
                AssertionError e = new AssertionError();
                if (assertionBean.getValue() == null) {
                    assertionBean.setValue(e);
                } else {
                    //assertionBean.getValue().addSuppressed(e);
                    assertionBean.setValue(e); // JDK 6 comp
                }
            }
        };
        Class instrumentedClass = instrumentClass(interceptor);
        instrumentedClass.getMethod("sayHello", String.class).invoke(null, "world");
        if (assertionBean.getValue() != null) {
            throw assertionBean.getValue();
        }
    }

    @Test
    public void testNested() throws Exception {
        final Bean<Integer> startCounterBean = new Bean<Integer>();
        startCounterBean.setValue(0);
        final Bean<Integer> finishCounterBean = new Bean<Integer>();
        finishCounterBean.setValue(0);
        final Bean<AssertionError> assertionBean = new Bean<AssertionError>();

        Interceptor interceptor = new VoidInterceptor() {

            @Override
            public void doOnStart(Object source, Object[] arg, String executionId) {
                startCounterBean.setValue(startCounterBean.getValue() + 1);
            }

            @Override
            public void doOnFinish(Object source, Object o, String executionId) {
                finishCounterBean.setValue(finishCounterBean.getValue() + 1);
            }

            @Override
            public void doOnThrowableThrown(Object source, Throwable th, String executionId) {
                AssertionError e = new AssertionError();
                if (assertionBean.getValue() == null) {
                    assertionBean.setValue(e);
                } else {
                    //assertionBean.getValue().addSuppressed(e);
                    assertionBean.setValue(e); // JDK 6 comp
                }
            }

            @Override
            public void doOnThrowableUncatched(Object source, Throwable th, String executionId) {
                AssertionError e = new AssertionError();
                if (assertionBean.getValue() == null) {
                    assertionBean.setValue(e);
                } else {
                    //assertionBean.getValue().addSuppressed(e);
                    assertionBean.setValue(e); // JDK 6 comp
                }
            }
        };
        Class instrumentedClass = instrumentClass(interceptor);
        instrumentedClass.getMethod("sayHelloDate", String.class).invoke(null, "world");
        if (assertionBean.getValue() != null) {
            throw assertionBean.getValue();
        }
        assertEquals(startCounterBean.getValue(), new Integer(2));
        assertEquals(finishCounterBean.getValue(), new Integer(2));
    }

    static class ByteClassLoader extends ClassLoader {

        public Class<?> loadClass(String name, byte[] byteCode) {
            return super.defineClass(name, byteCode, 0, byteCode.length);
        }
    }
}