org.apache.cassandra.cql3.functions.UDFByteCodeVerifier.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.cassandra.cql3.functions.UDFByteCodeVerifier.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.cassandra.cql3.functions;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.Handle;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

/**
 * Verifies Java UDF byte code.
 * Checks for disallowed method calls (e.g. {@code Object.finalize()}),
 * additional code in the constructor,
 * use of {@code synchronized} blocks,
 * too many methods.
 */
public final class UDFByteCodeVerifier {

    public static final String JAVA_UDF_NAME = JavaUDF.class.getName().replace('.', '/');
    public static final String OBJECT_NAME = Object.class.getName().replace('.', '/');
    public static final String CTOR_SIG = "(Lcom/datastax/driver/core/TypeCodec;[Lcom/datastax/driver/core/TypeCodec;)V";

    private final Set<String> disallowedClasses = new HashSet<>();
    private final Multimap<String, String> disallowedMethodCalls = HashMultimap.create();
    private final List<String> disallowedPackages = new ArrayList<>();

    public UDFByteCodeVerifier() {
        addDisallowedMethodCall(OBJECT_NAME, "clone");
        addDisallowedMethodCall(OBJECT_NAME, "finalize");
        addDisallowedMethodCall(OBJECT_NAME, "notify");
        addDisallowedMethodCall(OBJECT_NAME, "notifyAll");
        addDisallowedMethodCall(OBJECT_NAME, "wait");
    }

    public UDFByteCodeVerifier addDisallowedClass(String clazz) {
        disallowedClasses.add(clazz);
        return this;
    }

    public UDFByteCodeVerifier addDisallowedMethodCall(String clazz, String method) {
        disallowedMethodCalls.put(clazz, method);
        return this;
    }

    public UDFByteCodeVerifier addDisallowedPackage(String pkg) {
        disallowedPackages.add(pkg);
        return this;
    }

    public Set<String> verify(byte[] bytes) {
        Set<String> errors = new TreeSet<>(); // it's a TreeSet for unit tests
        ClassVisitor classVisitor = new ClassVisitor(Opcodes.ASM5) {
            public FieldVisitor visitField(int access, String name, String desc, String signature, Object value) {
                errors.add("field declared: " + name);
                return null;
            }

            public MethodVisitor visitMethod(int access, String name, String desc, String signature,
                    String[] exceptions) {
                if ("<init>".equals(name) && CTOR_SIG.equals(desc)) {
                    if (Opcodes.ACC_PUBLIC != access)
                        errors.add("constructor not public");
                    // allowed constructor - JavaUDF(TypeCodec returnCodec, TypeCodec[] argCodecs)
                    return new ConstructorVisitor(errors);
                }
                if ("executeImpl".equals(name) && "(ILjava/util/List;)Ljava/nio/ByteBuffer;".equals(desc)) {
                    if (Opcodes.ACC_PROTECTED != access)
                        errors.add("executeImpl not protected");
                    // the executeImpl method - ByteBuffer executeImpl(int protocolVersion, List<ByteBuffer> params)
                    return new ExecuteImplVisitor(errors);
                }
                if ("<clinit>".equals(name)) {
                    errors.add("static initializer declared");
                } else {
                    errors.add("not allowed method declared: " + name + desc);
                    return new ExecuteImplVisitor(errors);
                }
                return null;
            }

            public void visit(int version, int access, String name, String signature, String superName,
                    String[] interfaces) {
                if (!JAVA_UDF_NAME.equals(superName)) {
                    errors.add("class does not extend " + JavaUDF.class.getName());
                }
                if (access != (Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL | Opcodes.ACC_SUPER)) {
                    errors.add("class not public final");
                }
                super.visit(version, access, name, signature, superName, interfaces);
            }

            public void visitInnerClass(String name, String outerName, String innerName, int access) {
                errors.add("class declared as inner class");
                super.visitInnerClass(name, outerName, innerName, access);
            }
        };

        ClassReader classReader = new ClassReader(bytes);
        classReader.accept(classVisitor, ClassReader.SKIP_DEBUG);

        return errors;
    }

    private class ExecuteImplVisitor extends MethodVisitor {
        private final Set<String> errors;

        ExecuteImplVisitor(Set<String> errors) {
            super(Opcodes.ASM5);
            this.errors = errors;
        }

        public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
            if (disallowedClasses.contains(owner)) {
                errorDisallowed(owner, name);
            }
            Collection<String> disallowed = disallowedMethodCalls.get(owner);
            if (disallowed != null && disallowed.contains(name)) {
                errorDisallowed(owner, name);
            }
            if (!JAVA_UDF_NAME.equals(owner)) {
                for (String pkg : disallowedPackages) {
                    if (owner.startsWith(pkg))
                        errorDisallowed(owner, name);
                }
            }
            super.visitMethodInsn(opcode, owner, name, desc, itf);
        }

        private void errorDisallowed(String owner, String name) {
            errors.add("call to " + owner.replace('/', '.') + '.' + name + "()");
        }

        public void visitInsn(int opcode) {
            switch (opcode) {
            case Opcodes.MONITORENTER:
            case Opcodes.MONITOREXIT:
                errors.add("use of synchronized");
                break;
            }
            super.visitInsn(opcode);
        }
    }

    private static class ConstructorVisitor extends MethodVisitor {
        private final Set<String> errors;

        ConstructorVisitor(Set<String> errors) {
            super(Opcodes.ASM5);
            this.errors = errors;
        }

        public void visitInvokeDynamicInsn(String name, String desc, Handle bsm, Object... bsmArgs) {
            errors.add("Use of invalid method instruction in constructor");
            super.visitInvokeDynamicInsn(name, desc, bsm, bsmArgs);
        }

        public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
            if (!(Opcodes.INVOKESPECIAL == opcode && JAVA_UDF_NAME.equals(owner) && "<init>".equals(name)
                    && CTOR_SIG.equals(desc))) {
                errors.add("initializer declared");
            }
            super.visitMethodInsn(opcode, owner, name, desc, itf);
        }

        public void visitInsn(int opcode) {
            if (Opcodes.RETURN != opcode) {
                errors.add("initializer declared");
            }
            super.visitInsn(opcode);
        }
    }
}