com.haulmont.cuba.core.sys.dbupdate.DbUpdaterEngine.java Source code

Java tutorial

Introduction

Here is the source code for com.haulmont.cuba.core.sys.dbupdate.DbUpdaterEngine.java

Source

/*
 * Copyright (c) 2008-2016 Haulmont.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package com.haulmont.cuba.core.sys.dbupdate;

import com.haulmont.bali.db.DbUtils;
import com.haulmont.bali.db.QueryRunner;
import com.haulmont.cuba.core.sys.DbInitializationException;
import com.haulmont.cuba.core.sys.DbUpdater;
import com.haulmont.cuba.core.sys.PostUpdateScripts;
import com.haulmont.cuba.core.sys.persistence.DbmsSpecificFactory;
import groovy.lang.Binding;
import groovy.lang.Closure;
import groovy.lang.GroovyShell;
import groovy.lang.Script;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang.StringUtils;
import org.codehaus.groovy.control.CompilerConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import javax.sql.DataSource;
import java.io.IOException;
import java.sql.*;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class DbUpdaterEngine implements DbUpdater {
    // File extension (without the dot) of plain SQL scripts; key in extensionHandlers.
    private static final String SQL_EXTENSION = "sql";
    // Lines starting with this prefix are treated as comments by isEmpty().
    private static final String SQL_COMMENT_PREFIX = "--";
    // Separator between individual statements inside one SQL script (see ScriptSplitter).
    private static final String SQL_DELIMITER = "^";

    // File extension (without the dot) of Groovy scripts; key in extensionHandlers.
    private static final String GROOVY_EXTENSION = "groovy";
    // Suffix checked by executeGroovyScript() to decide whether a "postUpdate" binding is supplied.
    protected static final String UPGRADE_GROOVY_EXTENSION = "upgrade.groovy";

    // Banner prepended to the message of every RuntimeException thrown by this class.
    protected static final String ERROR = "\n" + "=================================================\n"
            + "ERROR: Database update failed. See details below.\n"
            + "=================================================\n";

    protected static final Logger log = LoggerFactory.getLogger(DbUpdaterEngine.class);

    // NOTE(review): the four fields below are never assigned in this class —
    // presumably subclasses populate them; confirm against the concrete engines.
    protected DataSource dataSource;

    protected String dbScriptsDirectory;
    protected String dbmsType;
    protected String dbmsVersion;

    // Set to true by dbInitialized() when a SYS_DB_CHANGELOG table is found; never reset to false here.
    protected boolean changelogTableExists = false;

    // Maps a script file extension to the handler that executes files of that kind.
    protected final Map<String, FileHandler> extensionHandlers = new HashMap<>();

    // Instance initializer: registers the built-in SQL and Groovy handlers.
    {
        extensionHandlers.put(SQL_EXTENSION, this::executeSqlScript);
        extensionHandlers.put(GROOVY_EXTENSION, this::executeGroovyScript);
    }

    /**
     * Creates an engine without any configuration; the constructor body is empty, so
     * {@code dataSource}, {@code dbScriptsDirectory}, {@code dbmsType} and {@code dbmsVersion}
     * keep their default values until something else assigns them.
     */
    protected DbUpdaterEngine() {
    }

    /**
     * @return the data source used for every database operation of this engine
     */
    public DataSource getDataSource() {
        return dataSource;
    }

    /**
     * Looks up scripts of the given type through a fresh {@link ScriptScanner}.
     *
     * @param scriptType kind of scripts to look for
     * @param moduleName module directory to restrict the search to; may be {@code null}
     * @return the scripts reported by the scanner
     */
    public List<ScriptResource> getScripts(ScriptType scriptType, @Nullable String moduleName) {
        ScriptScanner scanner = scriptScanner();
        List<ScriptResource> found = scanner.getScripts(scriptType, moduleName);
        log.trace("Found {} scripts: {}", scriptType, found);
        return found;
    }

    /**
     * Brings the database to the current state: runs the init procedure when the
     * database is not initialized yet, otherwise runs the update procedure.
     *
     * @throws DbInitializationException if the initialization check fails
     */
    @Override
    public void updateDatabase() throws DbInitializationException {
        boolean initialized = dbInitialized();
        if (!initialized) {
            doInit();
            return;
        }
        doUpdate();
    }

    /**
     * Lists the update scripts that still have to be executed, without running anything.
     * <p>
     * A script counts as already executed under the same rule {@link #doUpdate()} applies
     * ({@link #containsIgnoringPrefix(Collection, String)}), so the returned list matches
     * what {@link #updateDatabase()} would actually run.
     *
     * @return names of pending update scripts, in the order the scripts were found
     * @throws DbInitializationException if the database is not initialized (no SEC_USER table)
     *                                   or the SYS_DB_CHANGELOG table is missing
     */
    @Override
    public List<String> findUpdateDatabaseScripts() throws DbInitializationException {
        if (!dbInitialized()) {
            throw new DbInitializationException(
                    "Unable to determine required updates because SEC_USER table doesn't exist");
        }
        // dbInitialized() has just refreshed this flag while scanning the table list.
        if (!changelogTableExists) {
            throw new DbInitializationException(
                    "Unable to determine required updates because SYS_DB_CHANGELOG table doesn't exist");
        }

        List<ScriptResource> files = getUpdateScripts();
        Set<String> executedScripts = getExecutedScripts();
        List<String> pending = new ArrayList<>();
        for (ScriptResource file : files) {
            String name = getScriptName(file);
            // Same comparison as doUpdate(): an exact contains() would report scripts
            // that the real update considers executed and therefore skips.
            if (!containsIgnoringPrefix(executedScripts, name)) {
                pending.add(name);
            }
        }
        return pending;
    }

    /**
     * Creates a new scanner from the current {@code dbScriptsDirectory}, {@code dbmsType}
     * and {@code dbmsVersion}; a fresh instance is built on every call.
     */
    protected ScriptScanner scriptScanner() {
        return new ScriptScanner(dbScriptsDirectory, dbmsType, dbmsVersion);
    }

    /**
     * @return the scripts directory path as reported by the {@link ScriptScanner};
     *         used by {@code getScriptName(String)} to cut the directory part off script paths
     */
    protected String dbScriptDirectoryPath() {
        return scriptScanner().dbScriptDirectoryPath();
    }

    /**
     * Creates the SYS_DB_CHANGELOG table that records executed scripts.
     * <p>
     * The primary key column is 255 characters long when {@code dbmsType} is "mysql"
     * and 300 characters long otherwise.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if the DDL statement fails
     */
    protected void createChangelogTable() {
        log.trace("Creating SYS_DB_CHANGELOG table");
        String timeStampType = DbmsSpecificFactory.getDbmsFeatures().getTimeStampType();
        QueryRunner runner = new QueryRunner(getDataSource());

        int pkLength;
        if ("mysql".equals(dbmsType)) {
            pkLength = 255;
        } else {
            pkLength = 300;
        }

        StringBuilder ddl = new StringBuilder("create table SYS_DB_CHANGELOG(");
        ddl.append("SCRIPT_NAME varchar(").append(pkLength).append(") not null primary key, ");
        ddl.append("CREATE_TS ").append(timeStampType).append(" default current_timestamp, ");
        ddl.append("IS_INIT integer default 0)");

        try {
            runner.update(ddl.toString());
        } catch (SQLException e) {
            throw new RuntimeException(ERROR + "Error creating changelog table", e);
        }
    }

    /**
     * Resolves the changelog name of a script resource from its path.
     *
     * @param resource script resource
     * @return the script name derived from {@code resource.getPath()}
     */
    protected String getScriptName(ScriptResource resource) {
        return getScriptName(resource.getPath());
    }

    /**
     * Derives the script name from a script path: backslashes are converted to forward
     * slashes, everything up to and including the scripts directory plus one following
     * character is cut off, and any remaining leading slashes are removed.
     *
     * @param path full path of the script
     * @return the part of the path that follows the scripts directory
     */
    protected String getScriptName(String path) {
        String scriptsDir = dbScriptDirectoryPath();
        String normalized = path.replace("\\", "/");
        // NOTE(review): if the directory is not part of the path, indexOf() yields -1
        // and the cut position silently becomes scriptsDir.length().
        int nameStart = normalized.indexOf(scriptsDir) + scriptsDir.length() + 1;
        String tail = normalized.substring(nameStart);
        return tail.replaceAll("^/+", "");
    }

    /**
     * Checks whether the database has been initialized by looking for a table named
     * SEC_USER (case-insensitive) in the schema being inspected.
     * <p>
     * Side effect: sets {@code changelogTableExists} to {@code true} when a table named
     * SYS_DB_CHANGELOG is seen during the same scan. The flag is never reset to
     * {@code false} by this method.
     *
     * @return {@code true} if the SEC_USER table exists
     * @throws DbInitializationException if a {@link SQLException} occurs while connecting
     *                                   or reading the database metadata
     */
    public boolean dbInitialized() throws DbInitializationException {
        log.trace("Checking if the database is initialized");
        Connection connection = null;
        ResultSet tables = null;
        try {
            connection = getDataSource().getConnection();
            DatabaseMetaData dbMetaData = connection.getMetaData();
            DbProperties dbProperties = new DbProperties(getConnectionUrl(connection));
            boolean isSchemaByUser = DbmsSpecificFactory.getDbmsFeatures().isSchemaByUser();
            // The schema is either the connected user's name or the schema taken from the connection URL.
            String schemaName = isSchemaByUser ? dbMetaData.getUserName() : dbProperties.getCurrentSchemaProperty();
            tables = dbMetaData.getTables(null, schemaName, null, null);
            boolean found = false;
            // No early exit: both tables have to be looked for in the same pass.
            while (tables.next()) {
                String tableName = tables.getString("TABLE_NAME");
                if ("SYS_DB_CHANGELOG".equalsIgnoreCase(tableName)) {
                    log.trace("Found SYS_DB_CHANGELOG table");
                    changelogTableExists = true;
                }
                if ("SEC_USER".equalsIgnoreCase(tableName)) {
                    log.trace("Found SEC_USER table");
                    found = true;
                }
            }
            return found;
        } catch (SQLException e) {
            throw new DbInitializationException(true, "Error connecting to database: " + e.getMessage(), e);
        } finally {
            // Release the metadata result set explicitly instead of relying on the
            // connection (possibly a pooled one) to clean it up.
            if (tables != null) {
                try {
                    tables.close();
                } catch (SQLException ignored) {
                    // best-effort cleanup; the outcome of the check is already decided
                }
            }
            if (connection != null) {
                try {
                    connection.close();
                } catch (SQLException ignored) {
                    // best-effort cleanup; the outcome of the check is already decided
                }
            }
        }
    }

    /**
     * Initializes an empty database: creates the changelog table if it is missing, then
     * executes every init script and records it with the init flag set.
     * <p>
     * Whether or not the init scripts succeed, {@link #prepareScripts()} is invoked
     * afterwards, which records all update scripts as executed.
     */
    protected void doInit() {
        log.info("Initializing database");

        if (!changelogTableExists) {
            createChangelogTable();
        }

        List<ScriptResource> initScripts = getInitScripts();
        try {
            initScripts.forEach(script -> {
                executeScript(script);
                markScript(getScriptName(script), true);
            });
        } finally {
            prepareScripts();
        }

        log.info("Database initialized");
    }

    /**
     * Updates an already initialized database.
     * <p>
     * If the changelog table does not exist, it is created and every init and update
     * script is recorded as executed without being run; nothing else happens in that case.
     * Otherwise missing component init scripts are run first, and then every update
     * script not yet recorded (compared via {@link #containsIgnoringPrefix}) is executed
     * and, when its handler returns {@code true}, recorded.
     */
    protected void doUpdate() {
        log.info("Updating database...");

        if (!changelogTableExists) {
            log.info("Changelog table not found, creating it and marking all scripts as executed");
            createChangelogTable();

            List<ScriptResource> initScripts = getInitScripts();
            try {
                initScripts.forEach(script -> markScript(getScriptName(script), true));
            } finally {
                prepareScripts();
            }
            return;
        }

        runRequiredInitScripts();

        log.trace("Checking existing and executed update scripts");
        List<ScriptResource> updateScripts = getUpdateScripts();
        Set<String> executed = getExecutedScripts();
        for (ScriptResource script : updateScripts) {
            String scriptName = getScriptName(script);
            if (containsIgnoringPrefix(executed, scriptName)) {
                continue;
            }
            boolean shouldMark = executeScript(script);
            if (shouldMark) {
                markScript(scriptName, false);
            }
        }
        log.info("Database is up-to-date");
    }

    /**
     * Runs init scripts of components that were added after the database was created.
     * <p>
     * Every module directory except the last one (the application's own folder) is
     * inspected. If the directory has init scripts, none of them is recorded as executed,
     * and {@link #initializedByOwnScript} does not report it as initialized, all its init
     * scripts are executed and recorded. Afterwards, even when execution fails, all update
     * scripts of that directory are recorded as executed.
     */
    protected void runRequiredInitScripts() {
        log.trace("Checking executed init scripts for components");
        Set<String> executedScripts = getExecutedScripts();
        List<String> dirs = getModuleDirs();

        // the last folder belongs to the app itself, only component folders are checked
        int componentDirCount = dirs.size() - 1;
        for (int i = 0; i < componentDirCount; i++) {
            String dirName = dirs.get(i);
            List<ScriptResource> initScripts = getInitScripts(dirName);
            if (initScripts.isEmpty()) {
                continue;
            }

            // stop at the first init script that is already in the changelog
            boolean alreadyExecuted = false;
            Iterator<ScriptResource> candidates = initScripts.iterator();
            while (!alreadyExecuted && candidates.hasNext()) {
                String initScriptName = getScriptName(candidates.next());
                log.trace("Checking script {}", initScriptName);
                alreadyExecuted = containsIgnoringPrefix(executedScripts, initScriptName);
            }
            if (alreadyExecuted || initializedByOwnScript(executedScripts, dirName)) {
                continue;
            }

            log.info(
                    "No init scripts from " + dirName + " have been executed, running init scripts for "
                            + distinguishingSubstring(dirName));
            try {
                for (ScriptResource script : initScripts) {
                    executeScript(script);
                    markScript(getScriptName(script), true);
                }
            } finally {
                for (ScriptResource updateScript : getUpdateScripts(dirName)) {
                    markScript(getScriptName(updateScript), true);
                }
            }
        }
    }

    /**
     * Checks whether the changelog contains a script whose file name (the part after the
     * last '/') equals, ignoring case, {@code "01." + <dirName without its first three
     * characters> + "-create-db.sql"}.
     *
     * @param executedScripts names of scripts recorded in the changelog
     * @param dirName         module directory name
     * @return {@code true} if such a script has been executed
     */
    protected boolean initializedByOwnScript(Set<String> executedScripts, String dirName) {
        for (String executed : executedScripts) {
            String fileName = executed.substring(executed.lastIndexOf('/') + 1);
            String ownScript = "01." + dirName.substring(3) + "-create-db.sql";
            if (fileName.equalsIgnoreCase(ownScript)) {
                log.debug("Found executed '" + ownScript + "' script");
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether the collection contains the given script name when both sides are
     * compared by their {@link #distinguishingSubstring(String)}, i.e. without the
     * leading three characters.
     *
     * @param strings script names to search in
     * @param s       script name to look for
     * @return {@code true} if some element has the same distinguishing substring as {@code s}
     */
    protected boolean containsIgnoringPrefix(Collection<String> strings, String s) {
        // The key of the searched name does not depend on the element, so compute it once
        // instead of once per element.
        String wanted = distinguishingSubstring(s);
        return strings.stream().anyMatch(it -> distinguishingSubstring(it).equals(wanted));
    }

    /**
     * Strips the first three characters from a name; names of three characters or fewer
     * are returned unchanged.
     *
     * @param scriptName script or directory name
     * @return the name without its first three characters, or the name itself if it is too short
     */
    protected String distinguishingSubstring(String scriptName) {
        if (scriptName.length() <= 3) {
            return scriptName;
        }
        return scriptName.substring(3);
    }

    /**
     * Records every update script as executed, with the init flag set, without running it.
     */
    protected void prepareScripts() {
        getUpdateScripts().forEach(script -> markScript(getScriptName(script), true));
    }

    /**
     * Loads the names of all scripts recorded in SYS_DB_CHANGELOG.
     *
     * @return set of recorded script names
     * @throws RuntimeException wrapping the {@link SQLException} if the query fails
     */
    protected Set<String> getExecutedScripts() {
        QueryRunner runner = new QueryRunner(getDataSource());
        Set<String> executed;
        try {
            executed = runner.query("select SCRIPT_NAME from SYS_DB_CHANGELOG", resultSet -> {
                Set<String> names = new HashSet<>();
                while (resultSet.next()) {
                    names.add(resultSet.getString(1));
                }
                return names;
            });
        } catch (SQLException e) {
            throw new RuntimeException(ERROR + "Error loading executed scripts", e);
        }
        log.trace("Found executed scripts: {}", executed);
        return executed;
    }

    /**
     * Records a script in SYS_DB_CHANGELOG.
     *
     * @param name script name to store in SCRIPT_NAME
     * @param init stored in IS_INIT as 1 when {@code true}, 0 otherwise
     * @throws RuntimeException wrapping the {@link SQLException} if the insert fails
     */
    protected void markScript(String name, boolean init) {
        log.trace("Marking script as executed: {}", name);
        QueryRunner runner = new QueryRunner(getDataSource());
        int initFlag = init ? 1 : 0;
        Object[] params = { name, initFlag };
        try {
            runner.update("insert into SYS_DB_CHANGELOG (SCRIPT_NAME, IS_INIT) values (?, ?)", params);
        } catch (SQLException e) {
            throw new RuntimeException(ERROR + "Error updating SYS_DB_CHANGELOG", e);
        }
    }

    /**
     * Tells whether a piece of SQL contains nothing to execute, i.e. every line is either
     * blank or starts (after trimming) with the SQL comment prefix.
     *
     * @param sql SQL text, possibly spanning several lines
     * @return {@code true} if there is no line with actual content
     */
    protected boolean isEmpty(String sql) {
        return Arrays.stream(sql.split("\\r?\\n"))
                .map(String::trim)
                .allMatch(line -> line.startsWith(SQL_COMMENT_PREFIX) || StringUtils.isBlank(line));
    }

    /**
     * Executes an SQL script: the content is split into statements by the SQL delimiter
     * and every statement that is not empty (see {@link #isEmpty(String)}) is executed in turn.
     *
     * @param file script to execute
     * @return always {@code true}, meaning the script should be marked as executed
     * @throws RuntimeException if the content cannot be read or a statement fails
     */
    protected boolean executeSqlScript(ScriptResource file) {
        final String content;
        try {
            content = file.getContent();
        } catch (IOException e) {
            throw new RuntimeException(ERROR + "Error resolving SQL script " + file.name, e);
        }

        List<String> statements = new ScriptSplitter(SQL_DELIMITER).split(content);
        for (String statement : statements) {
            if (isEmpty(statement)) {
                continue;
            }
            log.debug("Executing SQL:\n" + statement);
            try {
                executeSql(statement);
            } catch (SQLException e) {
                throw new RuntimeException(
                        ERROR + "Error executing SQL script " + file.name + "\n" + e.getMessage(), e);
            }
        }
        return true;
    }

    /**
     * Executes a single SQL statement on a connection obtained from the data source.
     * The statement and the connection are closed quietly afterwards.
     *
     * @param sql statement to execute
     * @throws SQLException if obtaining the connection or executing the statement fails
     */
    protected void executeSql(String sql) throws SQLException {
        Connection conn = null;
        Statement stmt = null;
        try {
            conn = getDataSource().getConnection();
            stmt = conn.createStatement();
            stmt.execute(sql);
        } finally {
            // statement first, then the connection that owns it
            DbUtils.closeQuietly(stmt);
            DbUtils.closeQuietly(conn);
        }
    }

    /**
     * Executes a Groovy script with the following variables bound:
     * <ul>
     *     <li>{@code ds} - the data source of this engine</li>
     *     <li>{@code log} - a logger named after this class and the script file name
     *     (without the ".groovy" suffix), joined by '$'</li>
     *     <li>{@code postUpdate} - only for files whose name does not end with
     *     ".upgrade.groovy": a collector that warns that added actions will be ignored</li>
     * </ul>
     *
     * @param file script to execute
     * @return always {@code true}, meaning the script should be marked as executed
     * @throws RuntimeException wrapping any exception raised while preparing, parsing or running the script
     */
    protected boolean executeGroovyScript(ScriptResource file) {
        try {
            ClassLoader classLoader = getClass().getClassLoader();
            CompilerConfiguration cc = new CompilerConfiguration();
            cc.setRecompileGroovySource(true);

            Binding bind = new Binding();
            bind.setProperty("ds", getDataSource());
            bind.setProperty("log", LoggerFactory.getLogger(String.format("%s$%s", DbUpdaterEngine.class.getName(),
                    StringUtils.removeEndIgnoreCase(file.getName(), ".groovy"))));
            // Scripts other than "*.upgrade.groovy" get a postUpdate object that still accepts
            // closures but warns that they will not be run.
            if (!StringUtils.endsWithIgnoreCase(file.getName(), "." + UPGRADE_GROOVY_EXTENSION)) {
                bind.setProperty("postUpdate", new PostUpdateScripts() {
                    @Override
                    public void add(Closure closure) {
                        super.add(closure);

                        log.warn("Added post update action will be ignored");
                    }
                });
            }

            GroovyShell shell = new GroovyShell(classLoader, bind, cc);
            //noinspection UnnecessaryLocalVariable
            Script script = shell.parse(file.getContent());
            script.run();
        } catch (Exception e) {
            // Every failure (I/O, compilation, runtime) is reported uniformly with the ERROR banner.
            throw new RuntimeException(ERROR + "Error executing Groovy script " + file.name + "\n" + e.getMessage(),
                    e);
        }
        return true;
    }

    /**
     * Executes a script with the handler registered for its file extension.
     * Files without an extension, or with an extension that has no handler, are skipped
     * with a warning.
     *
     * @param file script to execute
     * @return the handler's result, or {@code false} if the script was skipped
     */
    protected boolean executeScript(ScriptResource file) {
        log.info("Executing script " + getScriptName(file));

        String extension = FilenameUtils.getExtension(file.getName());
        if (StringUtils.isEmpty(extension)) {
            log.warn("Update script ignored, file extension undefined:" + file.getName());
            return false;
        }
        if (!extensionHandlers.containsKey(extension)) {
            log.warn("Update script ignored, file handler for extension not found:" + file.getName());
            return false;
        }
        return extensionHandlers.get(extension).run(file);
    }

    /**
     * @return module directory names as reported by the {@link ScriptScanner};
     *         {@code runRequiredInitScripts()} treats the last entry as the application's own folder
     */
    protected List<String> getModuleDirs() {
        return scriptScanner().getModuleDirs();
    }

    /**
     * @return init scripts of all module directories, concatenated in module order
     */
    protected List<ScriptResource> getInitScripts() {
        List<ScriptResource> allInitScripts = new ArrayList<>();
        for (String moduleDir : getModuleDirs()) {
            allInitScripts.addAll(getInitScripts(moduleDir));
        }
        return allInitScripts;
    }

    /**
     * @param oneModuleDir module directory to take scripts from; may be {@code null}
     * @return init scripts found for the given module directory
     */
    protected List<ScriptResource> getInitScripts(@Nullable String oneModuleDir) {
        return getScripts(ScriptType.INIT, oneModuleDir);
    }

    /**
     * @return update scripts of all module directories, concatenated in module order
     */
    protected List<ScriptResource> getUpdateScripts() {
        List<ScriptResource> allUpdateScripts = new ArrayList<>();
        for (String moduleDir : getModuleDirs()) {
            allUpdateScripts.addAll(getUpdateScripts(moduleDir));
        }
        return allUpdateScripts;
    }

    /**
     * @param oneModuleDir module directory to take scripts from; may be {@code null}
     * @return update scripts found for the given module directory
     */
    protected List<ScriptResource> getUpdateScripts(@Nullable String oneModuleDir) {
        return getScripts(ScriptType.UPDATE, oneModuleDir);
    }

    /**
     * Reads the JDBC URL of a connection from its metadata on a best-effort basis.
     *
     * @param connection open connection
     * @return the connection URL, or {@code null} if it cannot be obtained for any reason
     */
    protected String getConnectionUrl(Connection connection) {
        try {
            DatabaseMetaData databaseMetaData = connection.getMetaData();
            return databaseMetaData.getURL();
        } catch (Throwable e) {
            // Deliberately broad: a failure here must not abort the caller. The cause is
            // logged so that the reason for the missing URL can still be diagnosed.
            log.warn("Unable to get connection url", e);
            return null;
        }
    }

    /**
     * Handler that executes script files of one kind; handlers are registered by file
     * extension in {@code extensionHandlers}.
     */
    public interface FileHandler {
        /**
         * Executes the given script.
         *
         * @param file script to execute
         * @return {@code true} if the script should be marked as executed, {@code false} otherwise
         */
        boolean run(ScriptResource file);
    }

    /**
     * Splits a script into commands by a delimiter. A single occurrence of the delimiter
     * separates commands; a doubled delimiter is not a separator and is replaced by a
     * single delimiter in the resulting command.
     */
    public static class ScriptSplitter {

        private final String delimiter;

        /**
         * @param delimiter literal text that separates commands
         */
        public ScriptSplitter(String delimiter) {
            this.delimiter = delimiter;
        }

        /**
         * Splits the script into commands.
         *
         * @param script script text
         * @return commands in their original order, with doubled delimiters collapsed to one
         */
        public List<String> split(String script) {
            String quoted = Pattern.quote(delimiter);
            // a delimiter that is neither preceded nor followed by another delimiter;
            // for "^" this is (?<!\^)\^(?!\^)
            String separatorRegex = "(?<!" + quoted + ")" + quoted + "(?!" + quoted + ")";
            String escapedDelimiter = delimiter + delimiter;

            List<String> commands = new ArrayList<>();
            for (String part : script.split(separatorRegex)) {
                commands.add(part.replace(escapedDelimiter, delimiter));
            }
            return commands;
        }
    }
}