org.apache.commons.weaver.privilizer.BlueprintingVisitor.java Source code

Java tutorial

Introduction

Below is the source code of org.apache.commons.weaver.privilizer.BlueprintingVisitor.java:

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.commons.weaver.privilizer;

import java.io.InputStream;
import java.lang.invoke.LambdaMetafactory;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.Pair;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.AdviceAdapter;
import org.objectweb.asm.commons.GeneratorAdapter;
import org.objectweb.asm.commons.Method;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.MethodNode;

/**
 * {@link ClassVisitor} to import so-called "blueprint methods".
 *
 * <p>The {@link Privilizing} configuration names blueprint types (optionally restricted to certain method names).
 * Whenever a method of the visited class calls a registered blueprint method via {@code INVOKESTATIC}, or uses one
 * as the implementation handle of a {@link LambdaMetafactory} lambda, a private static synthetic copy of that method
 * is added to the visited class and the call site is redirected to the copy. The visited class is buffered in a
 * {@link ClassNode} and only handed to the wrapped visitor in {@link #visitEnd()}.
 */
class BlueprintingVisitor extends Privilizer.PrivilizerClassVisitor {
    /**
     * Structural summary of a parsed class file.
     */
    static class TypeInfo {
        /** Access flags of the class. */
        final int access;
        /** Internal name of the superclass; may be {@code null} (e.g. for {@code java/lang/Object}). */
        final String superName;
        /** Declared fields, keyed by field name. */
        final Map<String, FieldNode> fields;
        /** Declared methods, keyed by name and descriptor. */
        final Map<Method, MethodNode> methods;

        TypeInfo(final int access, final String superName, final Map<String, FieldNode> fields,
                final Map<Method, MethodNode> methods) {
            super();
            this.access = access;
            this.superName = superName;
            this.fields = fields;
            this.methods = methods;
        }
    }

    private static final Type LAMBDA_METAFACTORY = Type.getType(LambdaMetafactory.class);

    /**
     * Build the lookup key identifying a method.
     * @param owner internal name of the declaring type
     * @param name method name
     * @param desc method descriptor
     * @return (owner type, method) pair
     */
    private static Pair<Type, Method> methodKey(final String owner, final String name, final String desc) {
        return Pair.of(Type.getObjectType(owner), new Method(name, desc));
    }

    /**
     * Get the package part of a type's internal name, or {@code ""} for the default package.
     * {@link StringUtils#substringBeforeLast(String, String)} alone returns the whole input when no {@code '/'} is
     * present, which would make two default-package classes appear to live in different packages.
     * @param type to inspect
     * @return internal package name
     */
    private static String packageName(final Type type) {
        final String internalName = type.getInternalName();
        if (internalName.indexOf('/') < 0) {
            return StringUtils.EMPTY;
        }
        return StringUtils.substringBeforeLast(internalName, "/");
    }

    /** Types declared as blueprints by the configuration. */
    private final Set<Type> blueprintTypes = new HashSet<>();
    /** Importable blueprint methods, keyed by (blueprint type, method). */
    private final Map<Pair<Type, Method>, MethodNode> blueprintRegistry = new HashMap<>();

    /** Methods already imported into the visited class, mapped to their names in the visited class. */
    private final Map<Pair<Type, Method>, String> importedMethods = new HashMap<>();

    /** Lazily populated cache of parsed class files. */
    private final Map<Type, TypeInfo> typeInfoCache = new HashMap<>();
    /** Lazily populated cache of field lookups, keyed by (owner type, field name). */
    private final Map<Pair<Type, String>, FieldAccess> fieldAccessMap = new HashMap<>();

    private final ClassVisitor nextVisitor;

    /**
     * Whether the visited class is an interface. Imported methods live in the visited class, so redirected call sites
     * and method handles must use this flag rather than whatever the original reference declared.
     */
    private boolean targetIsInterface;

    /**
     * Create a new {@link BlueprintingVisitor}.
     * @param privilizer owner
     * @param nextVisitor wrapped
     * @param config annotation
     */
    BlueprintingVisitor(@SuppressWarnings("PMD.UnusedFormalParameter") final Privilizer privilizer, //false positive
            final ClassVisitor nextVisitor, final Privilizing config) {
        privilizer.super(new ClassNode(Privilizer.ASM_VERSION));
        this.nextVisitor = nextVisitor;

        // load up blueprint methods:
        for (final Privilizing.CallTo callTo : config.value()) {
            final Type blueprintType = Type.getType(callTo.value());
            blueprintTypes.add(blueprintType);

            // an empty methods() array selects every method of the blueprint type:
            final Set<String> methodNames = new HashSet<>(Arrays.asList(callTo.methods()));

            typeInfo(blueprintType).methods.entrySet().stream()
                    .filter(e -> methodNames.isEmpty() || methodNames.contains(e.getKey().getName()))
                    .forEach(e -> blueprintRegistry.put(Pair.of(blueprintType, e.getKey()), e.getValue()));
        }
    }

    /**
     * Get (reading and caching on first use) the structural summary of a type.
     * @param type to describe
     * @return {@link TypeInfo}
     */
    private TypeInfo typeInfo(final Type type) {
        return typeInfoCache.computeIfAbsent(type, k -> {
            final ClassNode cn = read(k.getClassName());

            return new TypeInfo(cn.access, cn.superName,
                    cn.fields.stream().collect(Collectors.toMap(f -> f.name, Function.identity())),
                    cn.methods.stream()
                            .collect(Collectors.toMap(m -> new Method(m.name, m.desc), Function.identity())));
        });
    }

    /**
     * Parse a class file from the weaving environment, skipping debug info and expanding frames.
     * @param className binary name of the class to read
     * @return populated {@link ClassNode}
     * @throws IllegalStateException if the class file cannot be obtained or parsed
     */
    private ClassNode read(final String className) {
        final ClassNode result = new ClassNode(Privilizer.ASM_VERSION);
        try (InputStream bytecode = privilizer().env.getClassfile(className).getInputStream()) {
            new ClassReader(bytecode).accept(result, ClassReader.SKIP_DEBUG | ClassReader.EXPAND_FRAMES);
        } catch (final Exception e) {
            throw new IllegalStateException("Unable to read class file for " + className, e);
        }
        return result;
    }

    @Override
    @SuppressWarnings("PMD.UseVarargs") //overridden method
    public void visit(final int version, final int access, final String name, final String signature,
            final String superName, final String[] interfaces) {
        Validate.isTrue(!blueprintTypes.contains(Type.getObjectType(name)),
                "Class %s cannot declare itself as a blueprint!", name);
        targetIsInterface = (access & Opcodes.ACC_INTERFACE) != 0;
        super.visit(version, access, name, signature, superName, interfaces);
    }

    /**
     * Wrap each visited method so that references to registered blueprint methods are redirected to imported
     * copies.
     */
    @Override
    @SuppressWarnings("PMD.UseVarargs") //overridden method
    public MethodVisitor visitMethod(final int access, final String name, final String desc, final String signature,
            final String[] exceptions) {
        final MethodVisitor toWrap = super.visitMethod(access, name, desc, signature, exceptions);
        return new MethodInvocationHandler(toWrap) {
            @Override
            boolean shouldImport(final Pair<Type, Method> methodKey) {
                return blueprintRegistry.containsKey(methodKey);
            }
        };
    }

    /**
     * Copy a method into the visited class as a private static synthetic method, unless already done. Static calls
     * it makes to its own type or a supertype are imported recursively. Non-public fields it reads or writes are
     * accessed reflectively via {@link AccessibleAdvisor}. The copy is annotated {@link Privileged} unless the source
     * method is private.
     *
     * @param key owner type and method to import
     * @return name of the imported method in the visited class
     */
    private String importMethod(final Pair<Type, Method> key) {
        if (importedMethods.containsKey(key)) {
            return importedMethods.get(key);
        }
        final String result = new StringBuilder(key.getLeft().getInternalName().replace('/', '_')).append("$$")
                .append(key.getRight().getName()).toString();
        // register before visiting the body so recursive (self/mutual) calls resolve to the same name:
        importedMethods.put(key, result);
        privilizer().env.debug("importing %s#%s as %s", key.getLeft().getClassName(), key.getRight(), result);
        final int access = Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC;

        final MethodNode source = typeInfo(key.getLeft()).methods.get(key.getRight());

        final String[] exceptions = source.exceptions.toArray(ArrayUtils.EMPTY_STRING_ARRAY);

        // non-public fields accessed
        final Set<FieldAccess> fieldAccesses = new LinkedHashSet<>();

        source.accept(new MethodVisitor(Privilizer.ASM_VERSION) {
            @Override
            public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) {
                final FieldAccess fieldAccess = fieldAccess(Type.getObjectType(owner), name);

                super.visitFieldInsn(opcode, owner, name, desc);
                if (!Modifier.isPublic(fieldAccess.access)) {
                    fieldAccesses.add(fieldAccess);
                }
            }
        });

        final MethodNode withAccessibleAdvice = new MethodNode(access, result, source.desc, source.signature,
                exceptions);

        // spider own methods:
        MethodVisitor mv = new NestedMethodInvocationHandler(withAccessibleAdvice, key); //NOPMD

        if (!fieldAccesses.isEmpty()) {
            mv = new AccessibleAdvisor(mv, access, result, source.desc, new ArrayList<>(fieldAccesses));
        }
        source.accept(mv);

        // private can only be called by other privileged methods, so no need to mark as privileged
        if (!Modifier.isPrivate(source.access)) {
            withAccessibleAdvice.visitAnnotation(Type.getType(Privileged.class).getDescriptor(), false).visitEnd();
        }
        withAccessibleAdvice.accept(this.cv);

        return result;
    }

    /**
     * Look up (and cache) a field declared directly by {@code owner}.
     * @param owner declaring type
     * @param name field name
     * @return {@link FieldAccess}
     * @throws IllegalStateException if {@code owner} declares no such field
     */
    private FieldAccess fieldAccess(final Type owner, final String name) {
        return fieldAccessMap.computeIfAbsent(Pair.of(owner, name), k -> {
            final FieldNode fieldNode = typeInfo(k.getLeft()).fields.get(k.getRight());
            Validate.validState(fieldNode != null, "Could not locate %s.%s", k.getLeft().getClassName(),
                    k.getRight());
            return new FieldAccess(fieldNode.access, k.getLeft(), fieldNode.name, Type.getType(fieldNode.desc));
        });
    }

    /**
     * Finish the buffered class (including any imported methods) and replay it into the wrapped visitor.
     */
    @Override
    public void visitEnd() {
        super.visitEnd();
        ((ClassNode) cv).accept(nextVisitor);
    }

    /**
     * Redirects static invocations and lambda implementation handles that {@link #shouldImport(Pair)} accepts to
     * imported copies in the visited class.
     */
    private abstract class MethodInvocationHandler extends MethodVisitor {
        MethodInvocationHandler(final MethodVisitor mvr) {
            super(Privilizer.ASM_VERSION, mvr);
        }

        @Override
        public void visitMethodInsn(final int opcode, final String owner, final String name, final String desc,
                final boolean itf) {
            if (opcode == Opcodes.INVOKESTATIC) {
                final Pair<Type, Method> methodKey = methodKey(owner, name, desc);
                if (shouldImport(methodKey)) {
                    final String importedName = importMethod(methodKey);
                    // the imported copy is declared by the visited class, not by the original owner:
                    super.visitMethodInsn(opcode, className, importedName, desc, targetIsInterface);
                    return;
                }
            }
            visitNonImportedMethodInsn(opcode, owner, name, desc, itf);
        }

        /**
         * Emit a method invocation that is not redirected; subclasses may add validation.
         */
        protected void visitNonImportedMethodInsn(final int opcode, final String owner, final String name,
                final String desc, final boolean itf) {
            super.visitMethodInsn(opcode, owner, name, desc, itf);
        }

        @Override
        public void visitInvokeDynamicInsn(final String name, final String descriptor,
                final Handle bootstrapMethodHandle, final Object... bootstrapMethodArguments) {

            if (isLambda(bootstrapMethodHandle)) {
                // becomes a modified copy only when an implementation handle is redirected:
                Object[] args = bootstrapMethodArguments;

                Handle handle = null;

                for (int i = 0; i < args.length; i++) {
                    if (bootstrapMethodArguments[i] instanceof Handle) {
                        if (handle != null) {
                            // we don't know what to do with multiple handles; skip the whole thing:
                            args = bootstrapMethodArguments;
                            break;
                        }
                        handle = (Handle) args[i];

                        if (handle.getTag() == Opcodes.H_INVOKESTATIC) {
                            final Pair<Type, Method> methodKey = methodKey(handle.getOwner(), handle.getName(),
                                    handle.getDesc());

                            if (shouldImport(methodKey)) {
                                final String importedName = importMethod(methodKey);
                                args = bootstrapMethodArguments.clone();
                                args[i] = new Handle(handle.getTag(), className, importedName, handle.getDesc(),
                                        targetIsInterface);
                            }
                        }
                    }
                }
                if (handle != null) {
                    if (args == bootstrapMethodArguments) {
                        validateLambda(handle);
                    } else {
                        super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, args);
                        return;
                    }
                }
            }
            super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
        }

        /**
         * Hook invoked for a lambda implementation handle that is left unchanged; no-op by default.
         */
        protected void validateLambda(final Handle handle) {
        }

        abstract boolean shouldImport(Pair<Type, Method> methodKey);

        private boolean isLambda(final Handle handle) {
            return handle.getTag() == Opcodes.H_INVOKESTATIC
                    && LAMBDA_METAFACTORY.getInternalName().equals(handle.getOwner())
                    && "metafactory".equals(handle.getName());
        }
    }

    /**
     * Handler for the body of an imported method. It imports static calls to the blueprint's own type or its
     * supertypes. Every other invocation or method reference must be accessible from the visited class.
     */
    class NestedMethodInvocationHandler extends MethodInvocationHandler {
        final Pair<Type, Method> methodKey;
        final Type owner;

        NestedMethodInvocationHandler(final MethodVisitor mvr, final Pair<Type, Method> methodKey) {
            super(mvr);
            this.methodKey = methodKey;
            this.owner = methodKey.getLeft();
        }

        @Override
        protected void visitNonImportedMethodInsn(final int opcode, final String owner, final String name,
                final String desc, final boolean itf) {
            final Type ownerType = Type.getObjectType(owner);
            final Method m = new Method(name, desc);

            if (isAccessible(ownerType) && isAccessible(ownerType, m)) {
                super.visitNonImportedMethodInsn(opcode, owner, name, desc, itf);
            } else {
                throw new IllegalStateException(
                        String.format("Blueprint method %s.%s calls inaccessible method %s.%s", this.owner,
                                methodKey.getRight(), owner, m));
            }
        }

        @Override
        protected void validateLambda(final Handle handle) {
            super.validateLambda(handle);
            final Type ownerType = Type.getObjectType(handle.getOwner());
            final Method m = new Method(handle.getName(), handle.getDesc());

            if (!(isAccessible(ownerType) && isAccessible(ownerType, m))) {
                throw new IllegalStateException(
                        String.format("Blueprint method %s.%s utilizes inaccessible method reference %s::%s", owner,
                                methodKey.getRight(), handle.getOwner(), m));
            }
        }

        @Override
        boolean shouldImport(final Pair<Type, Method> methodKey) {
            // call anything called within a class hierarchy:
            final Type called = methodKey.getLeft();
            // "I prefer the short cut":
            if (called.equals(owner)) {
                return true;
            }
            try {
                final Class<?> inner = load(called);
                final Class<?> outer = load(owner);
                return inner.isAssignableFrom(outer);
            } catch (final ClassNotFoundException e) {
                return false;
            }
        }

        private Class<?> load(final Type type) throws ClassNotFoundException {
            return privilizer().env.classLoader.loadClass(type.getClassName());
        }

        private boolean isAccessible(final Type type) {
            final TypeInfo typeInfo = typeInfo(type);
            return isAccessible(type, typeInfo.access);
        }

        /**
         * Resolve {@code m} on {@code type} or its superclasses and check its accessibility from the visited class.
         * @throws IllegalStateException if no class in the superclass chain declares {@code m}
         */
        private boolean isAccessible(final Type type, final Method m) {
            Type t = type;
            while (t != null) {
                final TypeInfo typeInfo = typeInfo(t);
                final MethodNode methodNode = typeInfo.methods.get(m);
                if (methodNode != null) {
                    // package access is granted relative to the class that actually declares the method:
                    return isAccessible(t, methodNode.access);
                }
                t = Optional.ofNullable(typeInfo.superName).map(Type::getObjectType).orElse(null);
            }
            throw new IllegalStateException(String.format("Cannot find method %s.%s", type, m));
        }

        /**
         * Public members are accessible. Protected and private members are conservatively treated as inaccessible.
         * Package-private members are accessible only when {@code type} is in the same package as the visited class.
         */
        private boolean isAccessible(final Type type, final int access) {
            if (Modifier.isPublic(access)) {
                return true;
            }
            if (Modifier.isProtected(access) || Modifier.isPrivate(access)) {
                return false;
            }
            return Stream.of(target, type).map(BlueprintingVisitor::packageName).distinct().count() == 1;
        }
    }

    /**
     * For every non-public referenced field of an imported method, replaces with reflective calls. Additionally, for
     * every such field that is not accessible, sets the field's accessibility and clears it as the method exits.
     * Accesses to public fields are emitted unchanged.
     */
    private class AccessibleAdvisor extends AdviceAdapter {
        final Type bitSetType = Type.getType(BitSet.class);
        final Type classType = Type.getType(Class.class);
        final Type fieldType = Type.getType(java.lang.reflect.Field.class);
        final Type fieldArrayType = Type.getType(java.lang.reflect.Field[].class);
        final Type stringType = Type.getType(String.class);

        final List<FieldAccess> fieldAccesses;
        final Label begin = new Label();
        /** Local holding the {@code Field[]} parallel to {@link #fieldAccesses}. */
        int localFieldArray;
        /** Local {@link BitSet}: bit i set when field i was already accessible on entry. */
        int bitSet;
        /** Local loop counter. */
        int fieldCounter;

        AccessibleAdvisor(final MethodVisitor mvr, final int access, final String name, final String desc,
                final List<FieldAccess> fieldAccesses) {
            super(Privilizer.ASM_VERSION, mvr, access, name, desc);
            this.fieldAccesses = fieldAccesses;
        }

        @Override
        protected void onMethodEnter() {
            localFieldArray = newLocal(fieldArrayType);
            bitSet = newLocal(bitSetType);
            fieldCounter = newLocal(Type.INT_TYPE);

            // create localFieldArray
            push(fieldAccesses.size());
            newArray(fieldArrayType.getElementType());
            storeLocal(localFieldArray);

            // create bitSet
            newInstance(bitSetType);
            dup();
            push(fieldAccesses.size());
            invokeConstructor(bitSetType, Method.getMethod("void <init>(int)"));
            storeLocal(bitSet);

            // populate localFieldArray
            push(0);
            storeLocal(fieldCounter);
            for (final FieldAccess access : fieldAccesses) {
                prehandle(access);
                iinc(fieldCounter, 1);
            }
            mark(begin);
        }

        /**
         * Emit code that looks up the reflective field for {@code access}, stores it at
         * {@code localFieldArray[fieldCounter]}, and either records it as already accessible or makes it accessible.
         */
        private void prehandle(final FieldAccess access) {
            // push owner.class literal
            visitLdcInsn(access.owner);
            push(access.name);
            final Label next = new Label();
            invokeVirtual(classType, new Method("getDeclaredField", fieldType, new Type[] { stringType }));

            dup();
            // store the field at localFieldArray[fieldCounter]:
            loadLocal(localFieldArray);
            swap();
            loadLocal(fieldCounter);
            swap();
            arrayStore(fieldArrayType.getElementType());

            dup();
            invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("boolean isAccessible()"));

            final Label setAccessible = new Label();
            // if false, setAccessible:
            ifZCmp(EQ, setAccessible);

            // else pop field instance
            pop();
            // and record that he was already accessible:
            loadLocal(bitSet);
            loadLocal(fieldCounter);
            invokeVirtual(bitSetType, Method.getMethod("void set(int)"));
            goTo(next);

            mark(setAccessible);
            push(true);
            invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)"));

            mark(next);
        }

        @Override
        public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) {
            final Pair<Type, String> key = Pair.of(Type.getObjectType(owner), name);
            final FieldAccess fieldAccess = fieldAccessMap.get(key);
            Validate.isTrue(fieldAccess != null, "Cannot find field %s", key);

            if (!fieldAccesses.contains(fieldAccess)) {
                // public field: directly accessible, so emit the original instruction unchanged
                super.visitFieldInsn(opcode, owner, name, desc);
                return;
            }
            final int fieldIndex = fieldAccesses.indexOf(fieldAccess);
            final boolean put = opcode == PUTSTATIC || opcode == PUTFIELD;
            final boolean isStatic = opcode == GETSTATIC || opcode == PUTSTATIC;
            final boolean primitive = fieldAccess.type.getSort() < Type.ARRAY;

            if (put && primitive) {
                // box the new value first so every operand shuffled below occupies exactly one stack slot;
                // SWAP/DUP_X2 would otherwise corrupt a long or double value:
                valueOf(fieldAccess.type);
            }

            // push the java.lang.reflect.Field stored at localFieldArray[fieldIndex]:
            visitInsn(NOP);
            loadLocal(localFieldArray);
            push(fieldIndex);
            arrayLoad(fieldArrayType.getElementType());
            checkCast(fieldType);

            if (put) {
                if (isStatic) {
                    // [value, field] -> [field, null, value]
                    swap();
                    push((String) null);
                    swap();
                } else {
                    // [instance, value, field] -> [field, instance, value]
                    dupX2();
                    pop();
                }
                invokeVirtual(fieldType, Method.getMethod("void set(Object, Object)"));
            } else {
                if (isStatic) {
                    // [field] -> [field, null]
                    push((String) null);
                } else {
                    // [instance, field] -> [field, instance]
                    swap();
                }
                invokeVirtual(fieldType, Method.getMethod("Object get(Object)"));
                checkCast(Privilizer.wrap(fieldAccess.type));
                if (primitive) {
                    unbox(fieldAccess.type);
                }
            }
        }

        @Override
        public void visitMaxs(final int maxStack, final int maxLocals) {
            // put try-finally around the whole method
            final Label fny = mark();
            // null exception type signifies finally block:
            final Type exceptionType = null;
            catchException(begin, fny, exceptionType);
            onFinally();
            throwException();
            super.visitMaxs(maxStack, maxLocals);
        }

        @Override
        protected void onMethodExit(final int opcode) {
            // the exceptional path is covered by the catch-all handler emitted in visitMaxs
            if (opcode != ATHROW) {
                onFinally();
            }
        }

        /**
         * Emit a loop that calls {@code setAccessible(false)} on every field whose bit is clear in {@link #bitSet}.
         */
        private void onFinally() {
            // loop over fields and return any non-null element to being inaccessible:
            push(0);
            storeLocal(fieldCounter);

            final Label test = mark();
            final Label increment = new Label();
            final Label endFinally = new Label();

            loadLocal(fieldCounter);
            push(fieldAccesses.size());
            ifCmp(Type.INT_TYPE, GeneratorAdapter.GE, endFinally);

            loadLocal(bitSet);
            loadLocal(fieldCounter);
            invokeVirtual(bitSetType, Method.getMethod("boolean get(int)"));

            // if true, increment:
            ifZCmp(NE, increment);

            loadLocal(localFieldArray);
            loadLocal(fieldCounter);
            arrayLoad(fieldArrayType.getElementType());
            push(false);
            invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)"));

            mark(increment);
            iinc(fieldCounter, 1);
            goTo(test);
            mark(endFinally);
        }
    }
}