com.github.mbenson.privileged.weaver.PrivilegedMethodWeaver.java Source code

Java tutorial

Introduction

Here is the source code for com.github.mbenson.privileged.weaver.PrivilegedMethodWeaver.java

Source

/*
 *  Copyright the original author or authors.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package com.github.mbenson.privileged.weaver;

import java.io.IOException;
import java.lang.reflect.Modifier;
import java.nio.charset.Charset;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Logger;

import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtField;
import javassist.CtMethod;
import javassist.CtNewConstructor;
import javassist.CtNewMethod;
import javassist.CtPrimitiveType;
import javassist.NotFoundException;
import javassist.bytecode.Descriptor;

import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.text.StrBuilder;

import com.github.mbenson.privileged.Privileged;

/**
 * Handles weaving of methods annotated with {@link Privileged}.
 */
public abstract class PrivilegedMethodWeaver<SELF extends PrivilegedMethodWeaver<SELF>> {
    /**
     * Callback through which the weaver writes out a class it has created or
     * modified. The weaver invokes it for each generated action class and for
     * each type in which at least one method was woven.
     */
    public interface ClassFileWriter {
        /**
         * Writes the specified class.
         *
         * @param type the class to write
         * @throws CannotCompileException if the class cannot be compiled
         * @throws IOException if writing fails
         */
        void write(CtClass type) throws CannotCompileException, IOException;
    }

    /**
     * Minimal logging abstraction used by the weaver. The default
     * implementation installed by this class delegates to
     * {@code java.util.logging}; a different one can be supplied via
     * {@code loggingTo(Log)}.
     */
    public interface Log {
        /**
         * Logs a debug message (the default implementation uses level FINER).
         *
         * @param message the text to log
         */
        void debug(String message);

        /**
         * Logs a verbose message (the default implementation uses level FINE).
         *
         * @param message the text to log
         */
        void verbose(String message);

        /**
         * Logs an error message (the default implementation uses level SEVERE).
         *
         * @param message the text to log
         */
        void error(String message);

        /**
         * Logs an informational message (the default implementation uses level INFO).
         *
         * @param message the text to log
         */
        void info(String message);

        /**
         * Logs a warning message (the default implementation uses level WARNING).
         *
         * @param message the text to log
         */
        void warn(String message);
    }

    /**
     * Weaving policy: determines under which circumstances a woven method
     * routes its call through a {@link PrivilegedAction}.
     */
    public enum Policy {
        /**
         * Weaving is disabled altogether.
         */
        NEVER,

        /**
         * The presence of a {@link SecurityManager} is checked a single time
         * and the outcome is kept in a generated field that guards the
         * privileged branch.
         */
        ON_INIT(generateName("hasSecurityManager")),

        /**
         * The presence of a {@link SecurityManager} is checked every time a
         * {@link Privileged} method is executed.
         */
        DYNAMIC(HAS_SECURITY_MANAGER_CONDITION),

        /**
         * {@link Privileged} methods always run as privileged actions,
         * without any guard.
         */
        ALWAYS;

        /**
         * Java source expression emitted as the {@code if} guard around the
         * privileged branch, or {@code null} for policies without a guard.
         */
        private final String condition;

        /**
         * Creates a policy without a guard condition.
         */
        Policy() {
            this(null);
        }

        /**
         * Creates a policy guarded by the given source expression.
         *
         * @param condition guard expression, may be {@code null}
         */
        Policy(String condition) {
            this.condition = condition;
        }

        /**
         * @return whether this policy has a guard condition
         */
        private boolean isConditional() {
            return null != condition;
        }
    }

    /**
     * Simple name which, once passed through {@link #generateName(String)},
     * names the class attribute that records the policy a type was woven with.
     */
    protected static final String POLICY_NAME = "policyName";

    /** Suffix appended to the names of generated action classes. */
    private static final String ACTION_SUFFIX = "_ACTION";

    /** Format applied by {@link #generateName(String)} to build generated member names. */
    private static final String GENERATE_NAME = "__privileged_%s";
    /** Java source expression that is true when a security manager is installed. */
    private static final String HAS_SECURITY_MANAGER_CONDITION = "System.getSecurityManager() != null";

    /**
     * Builds the name of a generated member by applying the weaver's
     * generated-name format to a simple name.
     *
     * @param simple the simple name to decorate
     * @return the generated name
     */
    protected static String generateName(String simple) {
        final Object[] formatArgs = { simple };
        return String.format(GENERATE_NAME, formatArgs);
    }

    /**
     * Decodes a byte array as UTF-8 text.
     *
     * @param b the bytes to decode, may be {@code null}
     * @return the decoded string, or {@code null} if {@code b} is {@code null}
     */
    protected static String toString(byte[] b) {
        if (b == null) {
            return null;
        }
        final Charset utf8 = Charset.forName("UTF-8");
        return new String(b, utf8);
    }

    /** Weaving policy applied by this weaver; never {@code null}. */
    protected final Policy policy;

    /** Pool from which the action interface types are resolved; never {@code null}. */
    protected final ClassPool classPool;

    /**
     * Whether the policy has already been reported to the current log; reset
     * whenever a new log is installed.
     */
    private boolean settingsReported;

    /**
     * Current log. Defaults to an adapter over {@code java.util.logging}
     * that maps debug to FINER, verbose to FINE, info to INFO, warn to
     * WARNING and error to SEVERE.
     */
    private Log log = new Log() {
        final Logger julLogger = Logger.getLogger(PrivilegedMethodWeaver.class.getName());

        @Override
        public void error(String message) {
            julLogger.severe(message);
        }

        @Override
        public void warn(String message) {
            julLogger.warning(message);
        }

        @Override
        public void info(String message) {
            julLogger.info(message);
        }

        @Override
        public void verbose(String message) {
            julLogger.fine(message);
        }

        @Override
        public void debug(String message) {
            julLogger.finer(message);
        }

    };

    /**
     * Orders methods by name, then by signature. Equal arguments (including
     * two {@code null}s) compare as 0; otherwise {@code null} sorts first.
     */
    private static final Comparator<CtMethod> CTMETHOD_COMPARATOR = new Comparator<CtMethod>() {

        @Override
        public int compare(CtMethod left, CtMethod right) {
            if (ObjectUtils.equals(left, right)) {
                return 0;
            }
            if (left == null || right == null) {
                // both cannot be null here: that case was handled by the equality check
                return left == null ? -1 : 1;
            }
            final int byName = ObjectUtils.compare(left.getName(), right.getName());
            if (byName != 0) {
                return byName;
            }
            return ObjectUtils.compare(left.getSignature(), right.getSignature());
        }
    };

    /**
     * Collects the non-abstract methods declared by {@code type} that carry
     * the {@link Privileged} annotation, sorted by name and then signature.
     *
     * @param type the class to inspect
     * @return the sorted set of methods to weave, possibly empty
     * @throws ClassNotFoundException if an annotation type cannot be loaded
     */
    private static Set<CtMethod> getPrivilegedMethods(CtClass type) throws ClassNotFoundException {
        final TreeSet<CtMethod> privileged = new TreeSet<CtMethod>(CTMETHOD_COMPARATOR);
        final CtMethod[] declared = type.getDeclaredMethods();
        for (int i = 0; i < declared.length; i++) {
            final CtMethod candidate = declared[i];
            // the annotation is only looked up for concrete methods
            if (!Modifier.isAbstract(candidate.getModifiers()) && candidate.getAnnotation(Privileged.class) != null) {
                privileged.add(candidate);
            }
        }
        return privileged;
    }

    /**
     * Creates a weaver using the {@link Policy#DYNAMIC} policy.
     *
     * @param classPool the class pool to use, not {@code null}
     */
    public PrivilegedMethodWeaver(ClassPool classPool) {
        this(Policy.DYNAMIC, classPool);
    }

    /**
     * Creates a weaver using the given policy.
     *
     * @param policy the weaving policy, not {@code null}
     * @param classPool the class pool to use, not {@code null}
     */
    public PrivilegedMethodWeaver(Policy policy, ClassPool classPool) {
        Validate.notNull(policy, "policy");
        Validate.notNull(classPool, "classPool");
        this.policy = policy;
        this.classPool = classPool;
    }

    /**
     * Installs the log this weaver reports to. The weave policy is reported
     * again, to the new log, the next time a class is woven.
     *
     * @param log the log to use, not {@code null}
     * @return this weaver, for call chaining
     */
    @SuppressWarnings("unchecked")
    public SELF loggingTo(Log log) {
        this.log = Validate.notNull(log);
        this.settingsReported = false;
        // safe by the self-type convention: SELF is the concrete weaver type
        return (SELF) this;
    }

    /**
     * Weave the specified class: every non-abstract method annotated with
     * {@link Privileged} is rewritten according to this weaver's policy. If
     * anything was woven, the policy name is recorded as a class attribute
     * and the class is handed to the {@link ClassFileWriter}.
     * 
     * @param type the class to weave
     * @return whether any work was done; {@code false} if the class had
     *         already been woven with the same policy, if the policy is
     *         {@link Policy#NEVER}, or if no method was woven
     * @throws AlreadyWovenException if the class was previously woven with a
     *         different policy
     * @throws NotFoundException if a referenced type cannot be found
     * @throws IOException if writing a class fails
     * @throws CannotCompileException if generated code cannot be compiled
     * @throws ClassNotFoundException if an annotation type cannot be loaded
     */
    public boolean weave(CtClass type)
            throws NotFoundException, IOException, CannotCompileException, ClassNotFoundException {
        reportSettings();
        final String policyName = generateName(POLICY_NAME);
        final String policyValue = toString(type.getAttribute(policyName));
        if (policyValue != null) {
            verbose("%s already woven with policy %s", type.getName(), policyValue);
            if (!policy.name().equals(policyValue)) {
                throw new AlreadyWovenException(type.getName(), Policy.valueOf(policyValue));
            }
            return false;
        }
        boolean result = false;
        if (policy != Policy.NEVER) {
            final Set<CtMethod> privilegedMethods = getPrivilegedMethods(type);
            // only touch the class when there is at least one candidate method
            if (policy == Policy.ON_INIT && !privilegedMethods.isEmpty()) {
                debug("Initializing field %s to %s", policy.condition, HAS_SECURITY_MANAGER_CONDITION);
                // The guard field must be static: woven static methods refer to it,
                // and the security manager check is meant to happen once per class
                // rather than once per instance.
                final CtField hasSecurityManager = new CtField(CtClass.booleanType, policy.condition, type);
                hasSecurityManager.setModifiers(Modifier.PRIVATE | Modifier.STATIC | Modifier.FINAL);
                type.addField(hasSecurityManager, CtField.Initializer.byExpr(HAS_SECURITY_MANAGER_CONDITION));
            }
            for (final CtMethod m : privilegedMethods) {
                result |= weave(type, m);
            }
            if (result) {
                // mark the class so that it is not woven a second time
                type.setAttribute(policyName, policy.name().getBytes(Charset.forName("UTF-8")));
                getClassFileWriter().write(type);
            }
        }
        verbose(result ? "Wove class %s" : "Nothing to do for class %s", type.getName());
        return result;
    }

    /**
     * Formats a message and logs it at debug level.
     *
     * @param message a {@link String#format(String, Object...)} pattern
     * @param args the pattern arguments
     */
    protected void debug(String message, Object... args) {
        final String formatted = String.format(message, args);
        log.debug(formatted);
    }

    /**
     * Formats a message and logs it at verbose level.
     *
     * @param message a {@link String#format(String, Object...)} pattern
     * @param args the pattern arguments
     */
    protected void verbose(String message, Object... args) {
        final String formatted = String.format(message, args);
        log.verbose(formatted);
    }

    /**
     * Formats a message and logs it at warning level.
     *
     * @param message a {@link String#format(String, Object...)} pattern
     * @param args the pattern arguments
     */
    protected void warn(String message, Object... args) {
        final String formatted = String.format(message, args);
        log.warn(formatted);
    }

    /**
     * Supplies the writer used to write out woven types and generated action
     * classes.
     *
     * @return the class file writer
     */
    protected abstract ClassFileWriter getClassFileWriter();

    /**
     * Formats a message and logs it at info level.
     *
     * @param message a {@link String#format(String, Object...)} pattern
     * @param args the pattern arguments
     */
    protected void info(String message, Object... args) {
        final String formatted = String.format(message, args);
        log.info(formatted);
    }

    /**
     * Hook deciding whether methods of the given access level may be woven.
     * Methods for which this returns {@code false} are skipped with a
     * warning. This implementation permits every access level.
     *
     * @param accessLevel the access level of the method under consideration
     * @return {@code true} to weave the method, {@code false} to ignore it
     */
    protected boolean permitMethodWeaving(AccessLevel accessLevel) {
        return true;
    }

    /**
     * Generates, as a static nested class of {@code type}, the action class
     * whose {@code run()} method invokes {@code impl}. The action stores the
     * owning instance (for non-static methods) and a copy of every method
     * argument in final fields set by its constructor. The generated class is
     * handed to the {@link ClassFileWriter} before being returned.
     *
     * @param type the class that declares {@code impl}
     * @param impl the method the action delegates to
     * @param iface the action interface the generated class implements
     * @return the generated action class
     * @throws NotFoundException if a referenced type cannot be found
     * @throws CannotCompileException if generated code cannot be compiled
     * @throws IOException if writing the generated class fails
     */
    private CtClass createAction(CtClass type, CtMethod impl, Class<?> iface)
            throws NotFoundException, CannotCompileException, IOException {
        // run() only declares "throws Exception" when impl declares exceptions
        final boolean exc = impl.getExceptionTypes().length > 0;

        final CtClass actionType = classPool.get(iface.getName());

        final String simpleName = generateActionClassname(impl);
        debug("Creating action type %s for method %s", simpleName, toString(impl));
        // second argument: create the nested class as static
        final CtClass result = type.makeNestedClass(simpleName, true);
        result.addInterface(actionType);

        // instance methods need a reference to the object to invoke impl on
        final CtField owner;
        if (Modifier.isStatic(impl.getModifiers())) {
            owner = null;
        } else {
            owner = new CtField(type, generateName("owner"), result);
            owner.setModifiers(Modifier.PRIVATE | Modifier.FINAL);
            debug("Adding owner field %s to %s", owner.getName(), simpleName);
            result.addField(owner);
        }

        // one field per method parameter, named arg0, arg1, ...
        final List<String> propagatedParameters = new ArrayList<String>();
        int index = -1;
        for (final CtClass param : impl.getParameterTypes()) {
            final String f = String.format("arg%s", Integer.valueOf(++index));
            final CtField fld = new CtField(param, f, result);
            fld.setModifiers(Modifier.PRIVATE | Modifier.FINAL);
            debug("Copying parameter %s from %s to %s.%s", index, toString(impl), simpleName, f);
            result.addField(fld);
            propagatedParameters.add(f);
        }
        {
            // constructor: one parameter per declared field, each assigned to
            // the field of the same name
            final StrBuilder constructor = new StrBuilder(simpleName).append('(');
            boolean sep = false;
            final Body body = new Body();

            for (final CtField fld : result.getDeclaredFields()) {
                if (sep) {
                    constructor.append(", ");
                } else {
                    sep = true;
                }
                constructor.append(fld.getType().getName()).append(' ').append(fld.getName());
                body.appendLine("this.%1$s = %1$s;", fld.getName());
            }
            constructor.append(") ").append(body.complete());

            final String c = constructor.toString();
            debug("Creating action constructor:");
            // NOTE(review): the generated source is used as the format pattern
            // here; that is only safe while it contains no '%' character.
            debug(c);
            result.addConstructor(CtNewConstructor.make(c, result));
        }
        {
            // run(): invoke impl with the stored arguments and return its result
            final StrBuilder run = new StrBuilder("public Object run() ");
            if (exc) {
                run.append("throws Exception ");
            }
            final Body body = new Body();
            final CtClass rt = impl.getReturnType();
            final boolean isVoid = rt.equals(CtClass.voidType);
            if (!isVoid) {
                body.append("return ");
            }
            // static methods are invoked on the declaring class, instance
            // methods on the stored owner
            final String deref = Modifier.isStatic(impl.getModifiers()) ? type.getName() : owner.getName();
            final String call = String.format("%s.%s(%s)", deref, impl.getName(),
                    StringUtils.join(propagatedParameters, ", "));

            if (!isVoid && rt.isPrimitive()) {
                // run() returns Object, so primitive results are boxed explicitly
                body.appendLine("%2$s.valueOf(%1$s);", call, ((CtPrimitiveType) rt).getWrapperName());
            } else {
                body.append(call).append(';').appendNewLine();

                if (isVoid) {
                    // run() must still return a value
                    body.appendLine("return null;");
                }
            }

            run.append(body.complete());

            final String r = run.toString();
            debug("Creating run method:");
            // NOTE(review): same format-pattern caveat as for the constructor source
            debug(r);
            result.addMethod(CtNewMethod.make(r, result));
        }
        getClassFileWriter().write(result);
        debug("Returning action type %s", result);
        return result;
    }

    /**
     * Derives the simple name of the action class generated for a method:
     * the method name, followed for methods with parameters by {@code $$}
     * and a mangled form of the parameter descriptor, followed by the action
     * suffix.
     *
     * @param m the method an action class is generated for
     * @return the simple name of the action class
     * @throws NotFoundException if a parameter type cannot be found
     */
    private String generateActionClassname(CtMethod m) throws NotFoundException {
        final String methodName = m.getName();
        if (m.getParameterTypes().length == 0) {
            return methodName + ACTION_SUFFIX;
        }
        // drop the surrounding parentheses and any leading/trailing ';'
        final String descriptor = StringUtils.strip(Descriptor.getParamDescriptor(m.getSignature()), "(;)");
        // replace the descriptor characters that are not wanted in a class name
        final String mangled = descriptor.replace("[", "ARRAYOF_").replace('/', '_').replace(';', '$');
        return methodName + "$$" + mangled + ACTION_SUFFIX;
    }

    /**
     * Renders a method for log messages as its name immediately followed by
     * its signature.
     *
     * @param m the method to describe
     * @return the method name concatenated with its signature
     */
    private String toString(CtMethod m) {
        final String name = m.getName();
        final String signature = m.getSignature();
        return name + signature;
    }

    /**
     * Weaves a single method. The original body is copied to a private
     * method with a generated name, an action class invoking that copy is
     * generated, and the original method's body is replaced by code that runs
     * the action through {@code AccessController.doPrivileged}. For
     * conditional policies the privileged path is wrapped in an {@code if}
     * and the copy is called directly otherwise.
     *
     * @param type the class declaring {@code method}
     * @param method the method to weave
     * @return {@code true} if the method was woven, {@code false} if
     *         {@link #permitMethodWeaving(AccessLevel)} rejected it
     * @throws ClassNotFoundException declared for callers; see the throws clause
     * @throws CannotCompileException if generated code cannot be compiled
     * @throws NotFoundException if a referenced type cannot be found
     * @throws IOException if writing the generated action class fails
     */
    private boolean weave(CtClass type, CtMethod method)
            throws ClassNotFoundException, CannotCompileException, NotFoundException, IOException {
        final AccessLevel accessLevel = AccessLevel.of(method.getModifiers());
        if (!permitMethodWeaving(accessLevel)) {
            warn("Ignoring %s method %s.%s", accessLevel, type.getName(), toString(method));
            return false;
        }
        // NOTE(review): relies on the declaration order of AccessLevel, which is
        // defined elsewhere; presumably the levels ordered before PACKAGE are the
        // more widely accessible ones -- confirm.
        if (AccessLevel.PACKAGE.compareTo(accessLevel) > 0) {
            warn("Possible security leak: granting privileges to %s method %s.%s", accessLevel, type.getName(),
                    toString(method));
        }
        final String implName = generateName(method.getName());

        // keep the original body in a private copy that the action will call
        final CtMethod impl = CtNewMethod.copy(method, implName, type, null);
        impl.setModifiers(AccessLevel.PRIVATE.merge(method.getModifiers()));
        type.addMethod(impl);
        debug("Copied %2$s %1$s.%3$s to %4$s %1$s.%5$s", type.getName(), accessLevel, toString(method),
                AccessLevel.PRIVATE, toString(impl));

        // from here on, build the replacement body of the original method
        final Body body = new Body();
        if (policy.isConditional()) {
            body.startBlock("if (%s)", policy.condition);
        }

        final boolean exc = method.getExceptionTypes().length > 0;

        if (exc) {
            body.startBlock("try");
        }

        // methods declaring exceptions need the exception-capable action interface
        final Class<?> iface = exc ? PrivilegedExceptionAction.class : PrivilegedAction.class;
        final CtClass actionType = createAction(type, impl, iface);
        final String action = generateName("action");

        // instantiate the action; $0 and $1..$n are Javassist's names for
        // "this" and the method parameters
        body.append("final %s %s = new %s(", iface.getName(), action, actionType.getName());
        boolean firstParam;
        if (Modifier.isStatic(impl.getModifiers())) {
            firstParam = true;
        } else {
            body.append("$0");
            firstParam = false;
        }
        for (int i = 1, sz = impl.getParameterTypes().length; i <= sz; i++) {
            if (firstParam) {
                firstParam = false;
            } else {
                body.append(", ");
            }
            body.append('$').append(Integer.toString(i));
        }
        body.appendLine(");");

        final CtClass rt = method.getReturnType();
        final boolean isVoid = rt.equals(CtClass.voidType);

        final String doPrivileged = String.format("%1$s.doPrivileged(%2$s)", AccessController.class.getName(),
                action);
        if (isVoid) {
            body.append(doPrivileged).append(';').appendNewLine();
            if (policy.isConditional()) {
                // leave the method so the unprivileged call emitted after the
                // if block is not executed as well
                body.appendLine("return;");
            }
        } else {
            final String cast = rt.isPrimitive() ? ((CtPrimitiveType) rt).getWrapperName() : rt.getName();
            // don't worry about wrapper NPEs because we should be simply
            // passing back an autoboxed value, then unboxing again
            final String result = generateName("result");
            body.appendLine("final %1$s %3$s = (%1$s) %2$s;", cast, doPrivileged, result);
            body.append("return %s", result);
            if (rt.isPrimitive()) {
                // unbox explicitly, e.g. intValue() for an int return type
                body.append(".%sValue()", rt.getName());
            }
            body.append(';').appendNewLine();
        }

        if (exc) {
            // close the try block and unwrap whatever the action threw
            body.endBlock();
            final String e = generateName("e");
            body.startBlock("catch (%1$s %2$s)", PrivilegedActionException.class.getName(), e).appendNewLine();

            final String wrapped = generateName("wrapped");

            body.appendLine("final Exception %1$s = %2$s.getCause();", wrapped, e);
            // rethrow each exception type declared by the method as itself
            for (final CtClass thrown : method.getExceptionTypes()) {
                body.startBlock("if (%1$s instanceof %2$s)", wrapped, thrown.getName());
                body.appendLine("throw (%2$s) %1$s;", wrapped, thrown.getName());
                body.endBlock();
            }
            // anything else is rethrown as, or wrapped in, a RuntimeException
            body.appendLine(
                    "throw %1$s instanceof RuntimeException ? (RuntimeException) %1$s : new RuntimeException(%1$s);",
                    wrapped);
            body.endBlock();
        }

        if (policy.isConditional()) {
            // close if block we opened before:
            body.endBlock();
            // no security manager=> just call impl:
            if (!isVoid) {
                body.append("return ");
            }
            // $$ is Javassist's shorthand for all of the method's arguments
            body.appendLine("%s($$);", impl.getName());
        }

        final String block = body.complete().toString();
        debug("Setting body of %s to:\n%s", toString(method), block);
        method.setBody(block);
        return true;
    }

    /**
     * Logs the weave policy at info level, once per installed log.
     */
    private void reportSettings() {
        if (settingsReported) {
            return;
        }
        // flag first, exactly as before, so the report is attempted only once
        settingsReported = true;
        info("Weave policy == %s", policy);
    }
}