com.eclecticlogic.pedal.dialect.postgresql.CopyCommand.java Source code

Java tutorial

Introduction

Here is the source code for com.eclecticlogic.pedal.dialect.postgresql.CopyCommand.java

Source

/**
 * Copyright (c) 2014-2015 Eclectic Logic LLC
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * 
 */
package com.eclecticlogic.pedal.dialect.postgresql;

import java.beans.BeanInfo;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.io.StringReader;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import javassist.CannotCompileException;
import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtNewMethod;
import javassist.NotFoundException;

import javax.persistence.AttributeOverride;
import javax.persistence.AttributeOverrides;
import javax.persistence.Column;
import javax.persistence.Convert;
import javax.persistence.Embeddable;
import javax.persistence.EmbeddedId;
import javax.persistence.EntityManager;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.JoinColumn;

import org.postgresql.copy.CopyManager;
import org.postgresql.core.BaseConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;

import com.eclecticlogic.pedal.connection.ConnectionAccessor;
import com.eclecticlogic.pedal.provider.Consumer;
import com.eclecticlogic.pedal.provider.ProviderAccessSpi;

/**
 * WARNING: This class is in desperate need of refactoring. I envision creating a Java source-code generator or using
 * StringTemplate to clean this up. In the meantime, please accept my sincerest apologies :)
 * 
 * Allows a collection of JPA entities to be written to the database using the PostgreSQL COPY command.
 * 
 * @author kabram.
 *
 */
public class CopyCommand {

    // Collaborators supplied through the setters of this class.
    private ConnectionAccessor connectionAccessor;
    private ProviderAccessSpi providerAccessSpi;

    // Per-entity-class caches filled in by setupFor(): the comma-separated column list used in the COPY
    // statement and the generated extractor that renders one entity as one COPY row.
    // Declared final so that the maps are safely visible to every thread that sees this instance.
    private final ConcurrentHashMap<Class<? extends Serializable>, String> fieldNamesByClass = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<Class<? extends Serializable>, CopyExtractor<? extends Serializable>> extractorsByClass = new ConcurrentHashMap<>();
    // This is used to prevent linkage error due to concurrent creation of classes.
    private static final AtomicInteger extractorNameSuffix = new AtomicInteger();

    private static final Logger logger = LoggerFactory.getLogger(CopyCommand.class);

    /**
     * Sets the accessor used to unwrap the raw connection for the COPY operation.
     *
     * @param connectionAccessor Accessor whose {@code getRawConnection} result is cast to a PostgreSQL
     *        {@code BaseConnection} when data is inserted.
     */
    public void setConnectionAccessor(ConnectionAccessor connectionAccessor) {
        this.connectionAccessor = connectionAccessor;
    }

    /**
     * Sets the provider SPI used by this class to run work against a JDBC connection and to look up
     * table and schema names.
     *
     * @param providerAccessSpi Provider access implementation to use.
     */
    public void setProviderAccessSpi(ProviderAccessSpi providerAccessSpi) {
        this.providerAccessSpi = providerAccessSpi;
    }

    /**
     * Writes the given entities to the database using the Postgres COPY command. An empty list is a no-op.
     *
     * @param entityManager Entity manager handed to the provider SPI when the copy is executed.
     * @param entityList Entities to be inserted using the Postgres COPY command.
     */
    public <E extends Serializable> void insert(EntityManager entityManager, CopyList<E> entityList) {
        if (entityList.isEmpty()) {
            // Nothing to package or send.
            return;
        }
        PackagedCopyData packaged = packageCopyData(entityList);
        insert(entityManager, packaged);
    }

    /**
     * Packages up the entity list as standalone data that can be serialized and inserted at a later point.
     * <p>
     * The class of the first element determines the column list and the extractor, so the list must contain
     * at least one entity.
     *
     * @param entityList Non-empty list of entities to package.
     * @return The COPY statement, the newline-terminated row data (one row per entity) and the row count.
     */
    @SuppressWarnings("unchecked")
    public <E extends Serializable> PackagedCopyData packageCopyData(CopyList<E> entityList) {
        Class<? extends Serializable> clz = entityList.get(0).getClass();
        setupFor(clz);
        String fieldNames = fieldNamesByClass.get(clz);
        CopyExtractor<E> extractor = (CopyExtractor<E>) extractorsByClass.get(clz);

        // Estimate 1 KB per row, but compute in long and cap the result: the int product overflows to a
        // negative capacity for very large lists, and the builder grows on demand anyway.
        long estimatedChars = 1024L * entityList.size();
        int initialCapacity = (int) Math.min(estimatedChars, 16L * 1024 * 1024);
        StringBuilder builder = new StringBuilder(initialCapacity);
        for (E entity : entityList) {
            builder.append(extractor.getValueList(entity));
            builder.append('\n');
        }

        PackagedCopyData data = new PackagedCopyData();
        data.setBody(builder.toString());
        data.setSize(entityList.size());
        data.setCopySql("copy " + getEntityName(entityList) + "(" + fieldNames + ") from stdin");
        return data;
    }

    /**
     * Executes the COPY statement carried by the packaged data, streaming its body to the database.
     * Any failure is rethrown wrapped in a {@link RuntimeException}.
     *
     * @param entityManager Entity manager handed to the provider SPI to run the copy.
     * @param data Previously packaged copy sql, row body and row count.
     */
    public void insert(EntityManager entityManager, final PackagedCopyData data) {
        final StringReader bodyReader = new StringReader(data.getBody());

        Consumer<Connection> copyTask = new Consumer<Connection>() {

            @Override
            public void accept(Connection connection) {
                try {
                    BaseConnection rawConnection = (BaseConnection) connectionAccessor
                            .getRawConnection(connection);
                    CopyManager manager = new CopyManager(rawConnection);

                    long startedAt = System.currentTimeMillis();
                    manager.copyIn(data.getCopySql(), bodyReader);
                    long elapsedMillis = System.currentTimeMillis() - startedAt;

                    // Elapsed time in seconds, rounded to two decimal places.
                    double elapsedSeconds = Math.round(elapsedMillis / 10.0) / 100.0;
                    logger.info("Wrote {} inserts in {} seconds", data.getSize(), elapsedSeconds);
                } catch (Exception e) {
                    logger.trace("Command passed: {}", data);
                    throw new RuntimeException(e.getMessage(), e);
                }
            }
        };

        providerAccessSpi.run(entityManager, copyTask);
    }

    /**
     * Derives and caches, for the given entity class, the comma-separated column list and the generated
     * {@link CopyExtractor}. Does nothing when the column list for the class is already cached.
     * <p>
     * Columns are discovered from annotations on the public methods of the class. The list of column names
     * and the list of getter methods passed to the extractor are built in the same iteration order.
     *
     * @param clz Entity class to inspect.
     */
    @SuppressWarnings("unchecked")
    private void setupFor(Class<? extends Serializable> clz) {
        if (fieldNamesByClass.get(clz) != null) {
            return;
        }

        List<String> columns = new ArrayList<>();
        List<Method> getters = new ArrayList<>();
        for (Method getter : clz.getMethods()) {
            boolean identityKey = getter.isAnnotationPresent(Id.class)
                    && getter.isAnnotationPresent(GeneratedValue.class)
                    && getter.getAnnotation(GeneratedValue.class).strategy() == GenerationType.IDENTITY;
            if (identityKey) {
                // Ignore pk with identity strategy.
                continue;
            }

            String simpleColumn = null;
            if (getter.isAnnotationPresent(Column.class)) {
                simpleColumn = extractColumnName(getter, clz);
            } else if (getter.isAnnotationPresent(JoinColumn.class)
                    && getter.getAnnotation(JoinColumn.class).insertable()) {
                simpleColumn = getter.getAnnotation(JoinColumn.class).name();
            } else if (getter.isAnnotationPresent(EmbeddedId.class)) {
                // Column names of an embedded id come from its attribute overrides, when present.
                if (getter.isAnnotationPresent(AttributeOverrides.class)) {
                    addOverrideColumnNames(getter, columns);
                }
                getters.add(getter);
            } else if (getter.getReturnType().isAnnotationPresent(Embeddable.class)) {
                getters.add(getter);
                Class<?> embeddableType = getter.getReturnType();
                if (getter.isAnnotationPresent(AttributeOverrides.class)) {
                    // Embeddables can have column names remapped using attribute overrides.
                    // We expect all attributes to be overridden.
                    addOverrideColumnNames(getter, columns);
                } else {
                    // Use the column names from the embeddable class. Assumption: @Column annotation exists.
                    for (Method embeddedGetter : embeddableType.getMethods()) {
                        if (embeddedGetter.isAnnotationPresent(Column.class)) {
                            columns.add(extractColumnName(embeddedGetter,
                                    (Class<? extends Serializable>) embeddableType));
                        }
                    }
                }
            }

            // Certain one-to-one join situations can lead to multiple columns with the same column-name.
            if (simpleColumn != null && !columns.contains(simpleColumn)) {
                columns.add(simpleColumn);
                getters.add(getter);
            }
        }

        // The extractor is stored before the column list, which is the value checked on entry.
        extractorsByClass.put(clz, getExtractor(clz, getters));

        StringBuilder joined = new StringBuilder();
        for (String column : columns) {
            joined.append(',').append(column);
        }
        // Drop the leading comma.
        fieldNamesByClass.put(clz, joined.substring(1));
    }

    /**
     * Appends the column name of every {@link AttributeOverride} declared in the {@link AttributeOverrides}
     * annotation of the given method, in declaration order.
     */
    private void addOverrideColumnNames(Method annotated, List<String> columns) {
        AttributeOverrides overrides = annotated.getAnnotation(AttributeOverrides.class);
        for (AttributeOverride override : overrides.value()) {
            columns.add(override.column().name());
        }
    }

    /**
     * Resolves the database column name for a getter annotated with {@link Column}.
     * <p>
     * A class-level {@link AttributeOverrides} (or, failing that, a single {@link AttributeOverride}) whose
     * name matches the bean property read by the getter takes precedence; otherwise the name from the
     * getter's {@link Column} annotation is used.
     *
     * @param method Getter carrying the {@link Column} annotation.
     * @param clz Class that is introspected for the property name and inspected for overrides.
     * @return The column name to use.
     */
    private String extractColumnName(Method method, Class<? extends Serializable> clz) {
        // Find the bean property whose read method is the supplied getter.
        String propertyName = null;
        try {
            PropertyDescriptor[] descriptors = Introspector.getBeanInfo(clz).getPropertyDescriptors();
            for (int idx = 0; idx < descriptors.length; idx++) {
                if (method.equals(descriptors[idx].getReadMethod())) {
                    propertyName = descriptors[idx].getName();
                    break;
                }
            }
        } catch (IntrospectionException e) {
            throw new RuntimeException(e.getMessage(), e);
        }

        AttributeOverrides overrides = clz.getAnnotation(AttributeOverrides.class);
        if (overrides != null) {
            for (AttributeOverride candidate : overrides.value()) {
                if (candidate.name().equals(propertyName)) {
                    return candidate.column().name();
                }
            }
        } else {
            // The single-override form is only consulted when the container annotation is absent.
            AttributeOverride single = clz.getAnnotation(AttributeOverride.class);
            if (single != null && single.name().equals(propertyName)) {
                return single.column().name();
            }
        }
        return method.getAnnotation(Column.class).name();
    }

    /**
     * Generates, using javassist, a {@link CopyExtractor} implementation for the given entity class.
     * <p>
     * The generated {@code getValueList(Object)} method casts its argument to the entity class and appends
     * one value per entry of {@code fieldMethods}, separated by tab characters, writing {@code \N} for null
     * references. Each generated class gets a unique numeric name suffix.
     * <p>
     * Note: This method is now "ugly" and needs serious refactoring.
     *
     * @param clz Entity class the extractor is generated for.
     * @param fieldMethods Getter methods, in the order the values must appear in the row.
     * @return A new instance of the generated class that produces the copy row strings.
     * @throws RuntimeException if the generated source cannot be compiled or the class cannot be instantiated.
     */
    @SuppressWarnings({ "unchecked" })
    private <E extends Serializable> CopyExtractor<E> getExtractor(Class<E> clz, List<Method> fieldMethods) {
        ClassPool pool = ClassPool.getDefault();
        pool.insertClassPath(new ClassClassPath(clz));
        // The incrementing suffix gives every generated class a distinct name.
        CtClass cc = pool.makeClass("com.eclecticlogic.pedal.dialect.postgresql." + clz.getSimpleName()
                + "$CopyExtractor_" + extractorNameSuffix.incrementAndGet());

        StringBuilder methodBody = new StringBuilder();
        try {
            cc.addInterface(pool.getCtClass(CopyExtractor.class.getName()));

            methodBody.append("public String getValueList(Object entity) {\n");
            methodBody.append("try {\n");
            methodBody.append("StringBuilder builder = new StringBuilder();\n");
            methodBody.append(clz.getName() + " typed = (" + clz.getName() + ")entity;\n");
            for (int i = 0; i < fieldMethods.size(); i++) {
                Method method = fieldMethods.get(i);
                if (method.isAnnotationPresent(CopyConverter.class)) {
                    // A custom converter takes precedence over every other handling.
                    try {
                        Class<?> helper = method.getAnnotation(CopyConverter.class).value();
                        methodBody.append(helper.getName() + " c" + i + " = (" + helper.getName() + ")"
                                + helper.getName() + ".class.newInstance();\n");
                        methodBody
                                .append("builder.append(c" + i + ".convert(typed." + method.getName() + "()));\n");
                    } catch (Exception e) {
                        throw new RuntimeException(e.getMessage(), e);
                    }
                } else if (method.getReturnType().isPrimitive()) {
                    // Primitives cannot be null, so no null marker is needed.
                    methodBody.append("builder.append(typed." + method.getName() + "());\n");
                } else {
                    methodBody.append(
                            method.getReturnType().getName() + " v" + i + " = typed." + method.getName() + "();\n");
                    if (method.isAnnotationPresent(EmbeddedId.class)) {
                        // Embedded id
                        // NOTE(review): without @AttributeOverrides nothing is appended for this getter,
                        // although the tab separator below is still emitted - confirm this is intended.
                        if (method.isAnnotationPresent(AttributeOverrides.class)) {
                            AttributeOverrides overrides = method.getAnnotation(AttributeOverrides.class);
                            for (int j = 0; j < overrides.value().length; j++) {
                                AttributeOverride override = overrides.value()[j];
                                methodBody.append("if (v" + i + " == null) {builder.append(\"\\\\N\");}\n");
                                methodBody.append("else {\n");
                                Method idMethod = BeanUtils
                                        .getPropertyDescriptor(method.getReturnType(), override.name())
                                        .getReadMethod();
                                methodBody.append("builder.append(v" + i + "." + idMethod.getName() + "());\n");
                                methodBody.append("}\n");
                                if (j != overrides.value().length - 1) {
                                    methodBody.append("builder.append(\"\\t\");\n");
                                }
                            }
                        }
                    } else if (method.getReturnType().isAnnotationPresent(Embeddable.class)) {
                        // Embeddable. If attribute overrides is present, we get the columns in the order of the 
                        // override annotations. Otherwise use method names.
                        if (method.isAnnotationPresent(AttributeOverrides.class)) {
                            AttributeOverrides overrides = method.getAnnotation(AttributeOverrides.class);
                            for (int j = 0; j < overrides.value().length; j++) {
                                AttributeOverride override = overrides.value()[j];
                                Method embedded = BeanUtils
                                        .getPropertyDescriptor(method.getReturnType(), override.name())
                                        .getReadMethod();
                                if (embedded.getReturnType().isPrimitive()) {
                                    methodBody.append("builder.append(v" + i + "." + embedded.getName() + "());");
                                } else {
                                    methodBody.append("if (v" + i + "." + embedded.getName()
                                            + "() == null) {builder.append(\"\\\\N\");}");
                                    methodBody.append(
                                            "else {builder.append(v" + i + "." + embedded.getName() + "());}");
                                }
                                if (j != overrides.value().length - 1) {
                                    methodBody.append("builder.append(\"\\t\");\n");
                                }
                            } // end loop over override values.
                        } else {
                            // Get columns as discovered by method traversal. It's assumed that the @Column
                            // annotation exists on getter methods of interest.
                            List<Method> embeddedColumnMethods = new ArrayList<>();
                            for (Method embedded : method.getReturnType().getMethods()) {
                                if (embedded.isAnnotationPresent(Column.class)) {
                                    embeddedColumnMethods.add(embedded);
                                }
                            }

                            for (int j = 0; j < embeddedColumnMethods.size(); j++) {
                                Method embedded = embeddedColumnMethods.get(j);
                                if (embedded.getReturnType().isPrimitive()) {
                                    methodBody.append("builder.append(v" + i + "." + embedded.getName() + "());");
                                } else {
                                    methodBody.append("if (v" + i + "." + embedded.getName()
                                            + "() == null) {builder.append(\"\\\\N\");}");
                                    methodBody.append(
                                            "else {builder.append(v" + i + "." + embedded.getName() + "());}");
                                }
                                if (j != embeddedColumnMethods.size() - 1) {
                                    methodBody.append("builder.append(\"\\t\");\n");
                                }
                            } // end loop over embedded column methods.

                        }
                    } else {
                        methodBody.append("if (v" + i + " == null) {builder.append(\"\\\\N\");}\n");
                        methodBody.append("else {\n");

                        if (method.isAnnotationPresent(CopyAsBitString.class)) {
                            // Get number of bits from @Column annotation.
                            int bitCount = method.getAnnotation(Column.class).length();
                            methodBody.append("java.util.BitSet bs" + i + " = typed." + method.getName() + "();\n");
                            methodBody.append("for (int i = 0; i < " + bitCount + "; i++) {\n");
                            // NOTE(review): a set bit is written as "0" and a clear bit as "1"; this looks
                            // inverted - confirm against the code that reads the bit string back.
                            methodBody.append("builder.append(bs" + i + ".get(i) ? \"0\" : \"1\");\n");
                            methodBody.append("}\n");
                        } else if (Collection.class.isAssignableFrom(method.getReturnType())
                                && method.isAnnotationPresent(Convert.class) == false) {
                            // PostgreSQL array type.
                            if (method.isAnnotationPresent(CopyEmptyAsNull.class)) {
                                methodBody.append("if (typed." + method.getName() + "().isEmpty()) {\n");
                                methodBody.append("builder.append(\"\\\\N\");\n");
                                methodBody.append("} else {\n");
                                collectionExtractor(methodBody, i, method);
                                methodBody.append("}\n");
                            } else {
                                collectionExtractor(methodBody, i, method);
                            }
                        } else if (method.isAnnotationPresent(Convert.class)) {
                            Class<?> converterClass = method.getAnnotation(Convert.class).converter();
                            methodBody
                                    .append(converterClass.getName() + " c" + i + " = (" + converterClass.getName()
                                            + ")" + converterClass.getName() + ".class.newInstance();\n");
                            methodBody.append("builder.append(c" + i + ".convertToDatabaseColumn(v" + i + "));\n");
                        } else if (method.isAnnotationPresent(JoinColumn.class)) {
                            // We need to get the id of the joined object.
                            for (Method method2 : method.getReturnType().getMethods()) {
                                if (method2.isAnnotationPresent(Id.class)) {
                                    methodBody.append("builder.append(v" + i + "." + method2.getName() + "());\n");
                                }
                            }
                        } else {
                            methodBody.append("builder.append(v" + i + ");\n");
                        }
                        methodBody.append("}\n");
                    }
                }
                // Tab-separate the values; no trailing separator after the last one.
                if (i != fieldMethods.size() - 1) {
                    methodBody.append("builder.append(\"\\t\");\n");
                }
            }
            methodBody.append("return builder.toString();\n");
            methodBody.append("} catch (Exception e) { throw new RuntimeException(e); } \n");
            methodBody.append("}\n");
            logger.trace(methodBody.toString());
            cc.addMethod(CtNewMethod.make(methodBody.toString(), cc));
        } catch (NotFoundException | CannotCompileException e) {
            // Log the generated source together with the failure. The message is built by concatenation
            // because the exception text must not be interpreted as an SLF4J format string.
            logger.error(e.getMessage() + "\nCompiled body is:\n" + methodBody, e);
            throw new RuntimeException(e.getMessage(), e);
        }

        try {
            return (CopyExtractor<E>) cc.toClass().newInstance();
        } catch (InstantiationException | IllegalAccessException | CannotCompileException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Appends generated source that renders the collection returned by the given getter as a
     * brace-enclosed, comma-separated list (for example <code>{a,b,c}</code>; an empty collection yields
     * <code>{}</code>) and appends it to the generated {@code builder}.
     *
     * @param methodBody Buffer holding the source of the method being generated.
     * @param i Index of the field; used to make the generated local variable names unique.
     * @param method Getter on the generated {@code typed} variable that returns the collection.
     */
    private void collectionExtractor(StringBuilder methodBody, int i, Method method) {
        // Names of the local variables declared in the generated code.
        String iteratorVar = "it" + i;
        String arrayVar = "array" + i;
        String joinedVar = "arrayStr" + i;

        methodBody.append("java.util.Iterator ").append(iteratorVar).append(" = typed.")
                .append(method.getName()).append("().iterator();\n");
        methodBody.append("StringBuilder ").append(arrayVar).append(" = new StringBuilder();\n");
        methodBody.append("while (").append(iteratorVar).append(".hasNext()) {\n");
        methodBody.append("Object o = ").append(iteratorVar).append(".next();\n");
        methodBody.append(arrayVar).append(".append(\",\").append(o);\n");
        methodBody.append("}\n");
        // The generated code strips the leading comma, guarding against an empty collection.
        methodBody.append("String ").append(joinedVar).append(" = ").append(arrayVar)
                .append(".length() == 0 ? \"\" : ").append(arrayVar).append(".substring(1);\n");
        methodBody.append("builder.append(\"{\").append(").append(joinedVar).append(").append(\"}\");\n");
    }

    /**
     * Determines the table name to use in the COPY statement.
     * <p>
     * When the list carries no (non-blank) alternate table name, the provider is asked for the table name
     * of the first entity's class. Otherwise the alternate name is used, prefixed with the provider's
     * schema name and a dot when that schema name is non-blank.
     *
     * @param copyList List being copied; its first element is read when no alternate name is set.
     * @return Table name, possibly schema-qualified.
     */
    private <E extends Serializable> String getEntityName(CopyList<E> copyList) {
        String alternateName = copyList.getAlternateTableName();
        boolean hasAlternate = alternateName != null && alternateName.trim().length() != 0;
        if (!hasAlternate) {
            return providerAccessSpi.getTableName(copyList.get(0).getClass());
        }

        String schemaName = providerAccessSpi.getSchemaName();
        boolean hasSchema = schemaName != null && schemaName.trim().length() != 0;
        return hasSchema ? schemaName + "." + alternateName : alternateName;
    }

}