net.firejack.platform.core.utils.db.DBUtils.java Source code

Introduction

Here is the source code for net.firejack.platform.core.utils.db.DBUtils.java, a utility class of static JDBC helpers for checking, creating and dropping databases, executing single statements and script files, running simple single-value queries, and migrating data between two databases.

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package net.firejack.platform.core.utils.db;

import net.firejack.platform.api.registry.model.LogLevel;
import net.firejack.platform.core.exception.BusinessFunctionException;
import net.firejack.platform.core.model.registry.DatabaseName;
import net.firejack.platform.core.utils.OpenFlameDataSource;
import net.firejack.platform.core.utils.OpenFlameSpringContext;
import net.firejack.platform.core.utils.StringUtils;
import net.firejack.platform.model.service.reverse.analyzer.AbstractTableAnalyzer;
import net.firejack.platform.model.service.reverse.analyzer.MSSQLTableAnalyzer;
import net.firejack.platform.model.service.reverse.analyzer.MySQLTableAnalyzer;
import net.firejack.platform.model.service.reverse.analyzer.OracleTableAnalyzer;
import net.firejack.platform.model.service.reverse.bean.Column;
import net.firejack.platform.model.service.reverse.bean.Table;
import net.firejack.platform.model.service.reverse.bean.TablesMapping;
import net.firejack.platform.web.mina.aop.ManuallyProgress;
import org.apache.commons.io.IOUtils;
import org.apache.log4j.Logger;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCallback;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import javax.sql.DataSource;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.sql.*;
import java.util.*;

public class DBUtils {

    public static final int DEFAULT_BATCH_SIZE = 50;
    private static final Logger logger = Logger.getLogger(DBUtils.class);
    private static PreparedStatementCallback<Boolean> stubCallback = new StatementCallback();

    /**
     * Checks whether a database (schema) with the given name exists by querying information_schema.
     *
     * @param dataSource data-source pointing at the database server
     * @param dbName name of the database (schema) to look up
     * @return true if the database exists
     */
    public static boolean dbExists(DataSource dataSource, String dbName) {
        NamedParameterJdbcTemplate template = new NamedParameterJdbcTemplate(dataSource);
        SqlParameterSource namedParameters = new MapSqlParameterSource("dbName", dbName);
        String sql = "SELECT count(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = :dbName";
        Integer count = template.queryForObject(sql, namedParameters, Integer.class);
        return count != null && count > 0;
    }

    /**
     * Creates the given database if it does not exist yet.
     *
     * @param dataSource data-source pointing at the database server
     * @param databaseToCreate name of the database to create
     * @return true if the statement executed successfully
     */
    public static boolean createDatabase(DataSource dataSource, String databaseToCreate) {
        return executeStatement(dataSource,
                "CREATE DATABASE IF NOT EXISTS " + StringUtils.wrapWith("`", databaseToCreate) + ";");
    }

    /**
     * Drops the given database if it exists.
     *
     * @param dataSource data-source pointing at the database server
     * @param databaseToDrop name of the database to drop
     * @return true if the statement executed successfully
     */
    public static boolean dropDatabase(DataSource dataSource, String databaseToDrop) {
        return executeStatement(dataSource,
                "DROP DATABASE IF EXISTS " + StringUtils.wrapWith("`", databaseToDrop) + ";");
    }

    /**
     * Executes a single SQL statement, logging and swallowing any data access error.
     *
     * @param dataSource data-source to execute the statement against
     * @param sql SQL statement to execute
     * @return true if the statement executed without error
     */
    public static boolean executeStatement(DataSource dataSource, String sql) {
        JdbcTemplate template = new JdbcTemplate(dataSource);
        try {
            template.execute(sql);
        } catch (DataAccessException e) {
            logger.error(e.getMessage(), e);
            return false;
        }
        return true;
    }

    /**
     * Executes an SQL script file statement by statement. Empty statements and
     * statements starting with '#' are skipped; execution stops at the first failure.
     *
     * @param dataSource data-source to execute the script against
     * @param sqlScript script file to execute
     * @param delimiter statement delimiter used in the script (falls back to ";")
     * @return true if every statement executed successfully
     */
    public static boolean executeScript(DataSource dataSource, File sqlScript, String delimiter) {
        if (sqlScript == null || !sqlScript.exists()) {
            throw new IllegalArgumentException("Wrong script file parameter");
        }
        String sql;
        FileReader reader = null;
        try {
            reader = new FileReader(sqlScript);
            sql = IOUtils.toString(reader);
        } catch (IOException e) {
            logger.error(e.getMessage(), e);
            return false;
        } finally {
            IOUtils.closeQuietly(reader);
        }

        String[] sqls = sql.split(delimiter);
        if (sqls.length == 1) {
            sqls = sql.split(";");
        }
        for (String sqlStatement : sqls) {
            sqlStatement = sqlStatement.trim();
            if (!sqlStatement.isEmpty() && !sqlStatement.startsWith("#")) {
                if (!executeStatement(dataSource, sqlStatement)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Executes a query and extracts the result with the supplied ResultSetExtractor.
     *
     * @param sql query to execute
     * @param dataSource data-source to query
     * @param rsExtractor extractor that produces the result from the result set
     * @return extracted result
     * @throws CommonDBAccessException if the query fails
     */
    public static <T> T query(String sql, DataSource dataSource, ResultSetExtractor<T> rsExtractor)
            throws CommonDBAccessException {
        JdbcTemplate template = new JdbcTemplate(dataSource);
        try {
            return template.query(sql, rsExtractor);
        } catch (DataAccessException e) {
            logger.error(e.getMessage(), e);
            throw new CommonDBAccessException(e);
        }
    }

    /**
     * Executes a query and maps every row with the supplied RowMapper.
     *
     * @param sql query to execute
     * @param dataSource data-source to query
     * @param rowMapper mapper applied to each row of the result set
     * @return list of mapped rows
     * @throws CommonDBAccessException if the query fails
     */
    public static <T> List<T> query(String sql, DataSource dataSource, RowMapper<T> rowMapper)
            throws CommonDBAccessException {
        JdbcTemplate template = new JdbcTemplate(dataSource);
        try {
            return template.query(sql, rowMapper);
        } catch (DataAccessException e) {
            logger.error(e.getMessage(), e);
            throw new CommonDBAccessException(e);
        }
    }

    /**
     * Executes a query that is expected to return a single integer value.
     *
     * @param sql query to execute
     * @param dataSource data-source to query
     * @return the integer from the first column of the first row
     * @throws CommonDBAccessException if the query fails
     */
    public static Integer querySingleInt(String sql, DataSource dataSource) throws CommonDBAccessException {
        return query(sql, dataSource, new SingleIntegerRSExtractor());
    }

    /**
     * Executes a query that is expected to return a single string value.
     *
     * @param sql query to execute
     * @param dataSource data-source to query
     * @return the string from the first column of the first row
     * @throws CommonDBAccessException if the query fails
     */
    public static String querySingleString(String sql, DataSource dataSource) throws CommonDBAccessException {
        return query(sql, dataSource, new SingleStringRSExtractor());
    }

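    /**
     * Copies data from every table of the source database into the table with the same
     * name in the target database. Tables and columns are matched by name
     * (case-insensitive); if any source table cannot be fully mapped, no data is migrated.
     *
     * @param source data-source of the database to read from
     * @param target data-source of the database to write to
     */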
    public static void migrateData(OpenFlameDataSource source, OpenFlameDataSource target) {
        logger.info("Starting migration process");
        Connection sourceConnection = null;
        Connection targetConnection = null;
        try {
            sourceConnection = source.getConnection();
            targetConnection = target.getConnection();
            List<Table> sourceTables = getTables(sourceConnection, source);
            List<Table> targetTables = getTables(targetConnection, target);
            List<TablesMapping> mappings = mapTables(sourceTables, targetTables);

            if (mappings.size() == sourceTables.size()) {
                logger.info("Tables mapped successfully.");
                for (TablesMapping mapping : mappings) {
                    addLog("Migrate data from [" + mapping.getSourceTable().getName() + "] table.", 1,
                            LogLevel.INFO);
                    insertDataToTargetTable(mapping, sourceConnection, targetConnection);
                    addLog("Data migration from [" + mapping.getSourceTable().getName() + "] completed.", 1,
                            LogLevel.INFO);
                }
            } else {
                logger.warn("Failed to map all tables.");
                addLog("Failed to migrate data - databases are not identical.", 1, LogLevel.ERROR);
            }
        } catch (SQLException e) {
            logger.error(e.getMessage(), e);
            addLog(e.getMessage(), 1, LogLevel.ERROR);
        } finally {
            if (sourceConnection != null) {
                try {
                    sourceConnection.close();
                } catch (SQLException e) {
                    logger.error(e.getMessage(), e);
                }
            }
            if (targetConnection != null) {
                try {
                    targetConnection.close();
                } catch (SQLException e) {
                    logger.error(e.getMessage(), e);
                }
            }
        }
    }

    /**
     * @param driverClassName driver classname
     * @param url jdbc url
     * @param user db user
     * @param password db user password
     * @return returns populated data-source
     */
    public static DataSource populateDataSource(String driverClassName, String url, String user, String password) {
        DriverManagerDataSource dataSource = new DriverManagerDataSource();
        dataSource.setDriverClassName(driverClassName);
        dataSource.setUrl(url);
        dataSource.setUsername(user);
        dataSource.setPassword(password);
        return dataSource;
    }

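    /** Forwards a status message to the ManuallyProgress bean, if one is registered in the Spring context. */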
    private static void addLog(String message, int weight, LogLevel logLevel) {
        ManuallyProgress progress = OpenFlameSpringContext.getBean(ManuallyProgress.class);
        if (progress != null) {
            progress.status(message, weight, logLevel);
        }
    }

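    /** Callback that simply executes the supplied prepared statement and returns its result flag. */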
    private static class StatementCallback implements PreparedStatementCallback<Boolean> {
        @Override
        public Boolean doInPreparedStatement(PreparedStatement preparedStatement)
                throws SQLException, DataAccessException {
            return preparedStatement.execute();
        }
    }

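    /**
     * Matches source tables to target tables by name (case-insensitive) and builds a
     * column-by-column mapping for each pair. Tables whose columns cannot all be matched
     * are skipped, so the result may contain fewer entries than the source table list.
     */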
    private static List<TablesMapping> mapTables(List<Table> sourceTables, List<Table> targetTables) {
        List<TablesMapping> mapping = new ArrayList<TablesMapping>(sourceTables.size());
        for (Table sourceTable : sourceTables) {
            Table correspondingTargetTable = null;
            for (Table targetTable : targetTables) {
                if (sourceTable.getName().equalsIgnoreCase(targetTable.getName())) {
                    correspondingTargetTable = targetTable;
                    break;
                }
            }
            if (correspondingTargetTable == null) {
                logger.error("Failed to locate corresponding target table. Source table - [" + sourceTable.getName()
                        + "].");
            } else {
                List<Column> sourceTableColumns = sourceTable.getColumns();
                List<Column> targetTableColumns = correspondingTargetTable.getColumns();
                Map<Column, Column> columnMapping = new HashMap<Column, Column>();
                for (Column sourceColumn : sourceTableColumns) {
                    Column correspondingTargetColumn = null;
                    for (Column targetColumn : targetTableColumns) {
                        if (sourceColumn.getName().equalsIgnoreCase(targetColumn.getName())) {
                            correspondingTargetColumn = targetColumn;
                        }
                    }
                    if (correspondingTargetColumn == null) {
                        logger.error("Failed to locate corresponding target column, Source column - ["
                                + sourceColumn.getName() + "].");
                        columnMapping = null;
                        break;
                    } else {
                        columnMapping.put(sourceColumn, correspondingTargetColumn);
                    }
                }
                if (columnMapping != null) {
                    mapping.add(new TablesMapping(sourceTable, correspondingTargetTable, columnMapping));
                }
            }
        }
        return mapping;
    }

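    /**
     * Streams rows from the source table and inserts them into the target table in
     * batches of DEFAULT_BATCH_SIZE, committing after each batch and rolling back the
     * pending batch on failure. The column order of the SELECT and the INSERT stays
     * aligned because both statements are built from the same column mapping.
     */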
    private static void insertDataToTargetTable(TablesMapping mapping, Connection sourceConnection,
            Connection targetConnection) throws SQLException {
        Map<Column, Column> columnMapping = mapping.getColumnMapping();
        if (columnMapping.isEmpty()) {
            logger.warn("No columns are detected - no data to insert.");
        } else {
            ResultSet rs = selectDataFromSource(sourceConnection, mapping);

            String insertQuery = populateInsertQuery(mapping);
            PreparedStatement insertStatement = targetConnection.prepareStatement(insertQuery);
            targetConnection.setAutoCommit(false);
            try {
                int currentStep = 1;
                while (rs.next()) {
                    for (int i = 1; i <= columnMapping.size(); i++) {
                        insertStatement.setObject(i, rs.getObject(i));
                    }
                    insertStatement.addBatch();
                    if (++currentStep > DEFAULT_BATCH_SIZE) {
                        insertStatement.executeBatch();
                        targetConnection.commit();
                        currentStep = 1;
                    }
                }
                if (currentStep != 1) {
                    insertStatement.executeBatch();
                    targetConnection.commit();
                }
            } catch (SQLException e) {
                logger.error(e.getMessage(), e);
                targetConnection.rollback();
            } finally {
                insertStatement.close();
                rs.close();
            }
        }
    }

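    /**
     * Builds and executes a forward-only, read-only SELECT over the mapped source
     * columns, setting the fetch size to DEFAULT_BATCH_SIZE.
     */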
    private static ResultSet selectDataFromSource(Connection sourceConnection, TablesMapping mapping)
            throws SQLException {
        Map<Column, Column> columnMapping = mapping.getColumnMapping();
        StringBuilder selectQuery = new StringBuilder("select ");
        for (Map.Entry<Column, Column> columnEntry : columnMapping.entrySet()) {
            Column sourceColumn = columnEntry.getKey();
            selectQuery.append(sourceColumn.getName()).append(',');
        }
        if (!columnMapping.isEmpty()) {
            selectQuery.replace(selectQuery.length() - 1, selectQuery.length(), "");
        }
        selectQuery.append(" from ").append(mapping.getSourceTable().getName());
        String sql = selectQuery.toString();
        Statement statement = sourceConnection.createStatement(ResultSet.TYPE_FORWARD_ONLY,
                ResultSet.CONCUR_READ_ONLY);
        ResultSet rs = statement.executeQuery(sql);
        rs.setFetchSize(DEFAULT_BATCH_SIZE);
        return rs;
    }

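    /**
     * Builds a parameterized INSERT statement for the target table with one '?'
     * placeholder per mapped column.
     */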
    private static String populateInsertQuery(TablesMapping mapping) {
        Map<Column, Column> columnMapping = mapping.getColumnMapping();
        StringBuilder insertQuery = new StringBuilder("insert into ");
        insertQuery.append(mapping.getTargetTable().getName()).append('(');
        Set<Map.Entry<Column, Column>> columnEntries = columnMapping.entrySet();
        for (Map.Entry<Column, Column> entry : columnEntries) {
            Column targetColumn = entry.getValue();
            insertQuery.append(targetColumn.getName()).append(',');
        }
        if (!columnMapping.isEmpty()) {
            insertQuery.replace(insertQuery.length() - 1, insertQuery.length(), "");
        }
        insertQuery.append(") values (");
        for (int i = 0; i < columnEntries.size(); i++) {
            insertQuery.append(i == 0 ? '?' : ", ?");
        }
        insertQuery.append(')');
        return insertQuery.toString();
    }

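    /**
     * Reads the table metadata (including columns) visible through the connection. For
     * Oracle the lookup is restricted to the configured SID and schema; otherwise all
     * tables visible to the connection are returned.
     */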
    private static List<Table> getTables(Connection connection, OpenFlameDataSource dataSource)
            throws SQLException {
        List<Table> tableNames = new ArrayList<Table>();
        DatabaseMetaData metaData = connection.getMetaData();
        ResultSet rs;
        if (dataSource.getName() == DatabaseName.Oracle) {
            rs = metaData.getTables(dataSource.getSid(), dataSource.getSchema(), null, new String[] { "TABLE" });
        } else {
            rs = metaData.getTables(null, null, null, new String[] { "TABLE" });
        }
        while (rs.next()) {
            String tableType = rs.getString("TABLE_TYPE");
            if ("TABLE".equals(tableType)) {
                String tableName = rs.getString("TABLE_NAME");
                Table table = new Table();
                table.setName(tableName);
                List<Column> columns = getColumns(dataSource, metaData, table);
                table.setColumns(columns);
                tableNames.add(table);
            }
        }
        return tableNames;
    }

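    /**
     * Reads the column metadata of the given table and converts each entry into a
     * Column bean using the database-specific table analyzer.
     */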
    private static List<Column> getColumns(OpenFlameDataSource dataSource, DatabaseMetaData metaData, Table table)
            throws SQLException {
        ResultSet rs;
        if (dataSource.getName() == DatabaseName.Oracle) {
            rs = metaData.getColumns("orcl", "TIMUR", table.getName(), null);
        } else {
            rs = metaData.getColumns(null, null, table.getName(), null);
        }
        List<Column> columns = new ArrayList<Column>();
        AbstractTableAnalyzer dbAnalyzer = getDBAnalyzer(dataSource);
        while (rs.next()) {
            Column column = dbAnalyzer.createColumn(rs, table);
            columns.add(column);
        }
        return columns;
    }

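    /**
     * Picks the table analyzer matching the data-source vendor (MySQL, Oracle or MSSQL);
     * any other vendor results in a BusinessFunctionException.
     */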
    private static AbstractTableAnalyzer getDBAnalyzer(OpenFlameDataSource ds) {
        AbstractTableAnalyzer dbAnalyzer;
        switch (ds.getName()) {
        case MySQL:
            dbAnalyzer = new MySQLTableAnalyzer(ds, ds.getSchema());
            break;
        case Oracle:
            dbAnalyzer = new OracleTableAnalyzer(ds, ds.getSid(), ds.getSchema());
            break;
        case MSSQL:
            dbAnalyzer = new MSSQLTableAnalyzer(ds, ds.getSchema(), ds.getSid());
            break;
        default:
            throw new BusinessFunctionException(
                    "Could not produce db analyzer for database of type " + ds.getName().name());
        }
        return dbAnalyzer;
    }

}
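
Example

The snippet below is a minimal usage sketch rather than part of the original class. It assumes a reachable MySQL server with the Connector/J driver on the classpath; the JDBC URLs, credentials, the demo_db database name and the schema.sql file are placeholders chosen purely for illustration. migrateData is not shown because it expects OpenFlameDataSource instances configured in the platform's Spring context.

import net.firejack.platform.core.utils.db.DBUtils;

import javax.sql.DataSource;
import java.io.File;

public class DBUtilsExample {

    public static void main(String[] args) throws Exception {
        // Server-level data-source used for existence checks and CREATE DATABASE.
        DataSource serverDataSource = DBUtils.populateDataSource(
                "com.mysql.jdbc.Driver",                           // driver class name
                "jdbc:mysql://localhost:3306/information_schema",  // placeholder URL
                "root", "secret");                                  // placeholder credentials

        if (!DBUtils.dbExists(serverDataSource, "demo_db")) {
            DBUtils.createDatabase(serverDataSource, "demo_db");
        }

        // Data-source pointing at the new database; run a schema script using ";" as the delimiter.
        DataSource demoDataSource = DBUtils.populateDataSource(
                "com.mysql.jdbc.Driver",
                "jdbc:mysql://localhost:3306/demo_db",
                "root", "secret");
        DBUtils.executeScript(demoDataSource, new File("schema.sql"), ";");

        // Single-value query helper.
        Integer tableCount = DBUtils.querySingleInt(
                "SELECT count(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = 'demo_db'",
                demoDataSource);
        System.out.println("Tables in demo_db: " + tableCount);
    }
}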