com.intellij.compiler.notNullVerification.NotNullVerifyingInstrumenter.java Source code

Java tutorial

Introduction

Here is the source code for com.intellij.compiler.notNullVerification.NotNullVerifyingInstrumenter.java

Source

/*
 * Copyright 2000-2013 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intellij.compiler.notNullVerification;

import org.objectweb.asm.*;

import java.util.ArrayList;
import java.util.List;

/**
 * @author ven
 * @author Vladislav.Rassokhin
 */
public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcodes {
    private static final String NOT_NULL_CLASS_NAME = "org/jetbrains/annotations/NotNull";
    private static final String NOT_NULL_TYPE = "L" + NOT_NULL_CLASS_NAME + ";";
    private static final String SYNTHETIC_CLASS_NAME = "java/lang/Synthetic";
    private static final String SYNTHETIC_TYPE = "L" + SYNTHETIC_CLASS_NAME + ";";
    private static final String IAE_CLASS_NAME = "java/lang/IllegalArgumentException";
    private static final String ISE_CLASS_NAME = "java/lang/IllegalStateException";
    private static final String STRING_CLASS_NAME = "java/lang/String";
    private static final String CONSTRUCTOR_NAME = "<init>";
    private static final String EXCEPTION_INIT_SIGNATURE = "(L" + STRING_CLASS_NAME + ";)V";

    private static final String NULL_ARG_MESSAGE = "Argument %d for @NotNull parameter of %s.%s must not be null";
    private static final String NULL_RESULT_MESSAGE = "@NotNull method %s.%s must not return null";

    private String myClassName;
    private boolean myIsModification = false;
    private RuntimeException myPostponedError;

    public NotNullVerifyingInstrumenter(final ClassVisitor classVisitor) {
        super(Opcodes.ASM5, classVisitor);
    }

    public boolean isModification() {
        return myIsModification;
    }

    @Override
    public void visit(final int version, final int access, final String name, final String signature,
            final String superName, final String[] interfaces) {
        super.visit(version, access, name, signature, superName, interfaces);
        myClassName = name;
    }

    @Override
    public MethodVisitor visitMethod(final int access, final String name, final String desc, final String signature,
            final String[] exceptions) {
        final Type[] args = Type.getArgumentTypes(desc);
        final Type returnType = Type.getReturnType(desc);
        final MethodVisitor v = cv.visitMethod(access, name, desc, signature, exceptions);
        return new MethodVisitor(Opcodes.ASM5, v) {

            private final List<Integer> myNotNullParams = new ArrayList<Integer>();
            private int mySyntheticCount = 0;
            private boolean myIsNotNull = false;
            private Label myStartGeneratedCodeLabel;

            @Override
            public AnnotationVisitor visitParameterAnnotation(final int parameter, final String anno,
                    final boolean visible) {
                final AnnotationVisitor av = mv.visitParameterAnnotation(parameter, anno, visible);
                if (isReferenceType(args[parameter]) && anno.equals(NOT_NULL_TYPE)) {
                    myNotNullParams.add(parameter);
                } else if (anno.equals(SYNTHETIC_TYPE)) {
                    // see http://forge.ow2.org/tracker/?aid=307392&group_id=23&atid=100023&func=detail
                    mySyntheticCount++;
                }
                return av;
            }

            @Override
            public AnnotationVisitor visitAnnotation(String anno, boolean isRuntime) {
                final AnnotationVisitor av = mv.visitAnnotation(anno, isRuntime);
                if (isReferenceType(returnType) && anno.equals(NOT_NULL_TYPE)) {
                    myIsNotNull = true;
                }

                return av;
            }

            @Override
            public void visitCode() {
                if (myNotNullParams.size() > 0) {
                    myStartGeneratedCodeLabel = new Label();
                    mv.visitLabel(myStartGeneratedCodeLabel);
                }
                for (Integer param : myNotNullParams) {
                    int var = ((access & ACC_STATIC) == 0) ? 1 : 0;
                    for (int i = 0; i < param; ++i) {
                        var += args[i].getSize();
                    }
                    mv.visitVarInsn(ALOAD, var);

                    Label end = new Label();
                    mv.visitJumpInsn(IFNONNULL, end);

                    generateThrow(IAE_CLASS_NAME,
                            String.format(NULL_ARG_MESSAGE, param - mySyntheticCount, myClassName, name), end);
                }
            }

            @Override
            public void visitLocalVariable(final String name, final String desc, final String signature,
                    final Label start, final Label end, final int index) {
                final boolean isStatic = (access & ACC_STATIC) != 0;
                final boolean isParameter = isStatic ? index < args.length : index <= args.length;
                final Label label = (isParameter && myStartGeneratedCodeLabel != null) ? myStartGeneratedCodeLabel
                        : start;
                mv.visitLocalVariable(name, desc, signature, label, end, index);
            }

            @Override
            public void visitInsn(int opcode) {
                if (opcode == ARETURN) {
                    if (myIsNotNull) {
                        mv.visitInsn(DUP);
                        final Label skipLabel = new Label();
                        mv.visitJumpInsn(IFNONNULL, skipLabel);
                        generateThrow(ISE_CLASS_NAME, String.format(NULL_RESULT_MESSAGE, myClassName, name),
                                skipLabel);
                    }
                }

                mv.visitInsn(opcode);
            }

            private void generateThrow(final String exceptionClass, final String descr, final Label end) {
                mv.visitTypeInsn(NEW, exceptionClass);
                mv.visitInsn(DUP);
                mv.visitLdcInsn(descr);
                mv.visitMethodInsn(INVOKESPECIAL, exceptionClass, CONSTRUCTOR_NAME, EXCEPTION_INIT_SIGNATURE);
                mv.visitInsn(ATHROW);
                mv.visitLabel(end);

                myIsModification = true;
                processPostponedErrors();
            }

            @Override
            public void visitMaxs(final int maxStack, final int maxLocals) {
                try {
                    super.visitMaxs(maxStack, maxLocals);
                } catch (Throwable e) {
                    registerError(name, "visitMaxs", e);
                }
            }
        };
    }

    private static boolean isReferenceType(final Type type) {
        return type.getSort() == Type.OBJECT || type.getSort() == Type.ARRAY;
    }

    private void registerError(String methodName, String operationName, Throwable e) {
        if (myPostponedError == null) {
            // throw the first error that occurred
            Throwable err = e.getCause();
            if (err == null) {
                err = e;
            }
            myPostponedError = new RuntimeException("Operation '" + operationName + "' failed for " + myClassName
                    + "." + methodName + "(): " + err.getMessage(), err);
        }
        if (myIsModification) {
            processPostponedErrors();
        }
    }

    private void processPostponedErrors() {
        final RuntimeException error = myPostponedError;
        if (error != null) {
            throw error;
        }
    }
}