com.dangdang.ddframe.rdb.integrate.AbstractDBUnitTest.java Source code

Java tutorial

Introduction

Here is the source code for com.dangdang.ddframe.rdb.integrate.AbstractDBUnitTest.java

Source

/*
 * Copyright 1999-2015 dangdang.com.
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * </p>
 */

package com.dangdang.ddframe.rdb.integrate;

import com.dangdang.ddframe.rdb.sharding.constants.DatabaseType;
import org.apache.commons.dbcp.BasicDataSource;
import org.dbunit.DatabaseUnitException;
import org.dbunit.IDatabaseTester;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITable;
import org.dbunit.dataset.xml.FlatXmlDataSetBuilder;
import org.dbunit.ext.h2.H2Connection;
import org.dbunit.ext.mysql.MySqlConnection;
import org.dbunit.operation.DatabaseOperation;
import org.h2.tools.RunScript;
import org.junit.Before;

import javax.sql.DataSource;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.dbunit.Assertion.assertEquals;

/**
 * Base class for DBUnit-backed integration tests.
 *
 * <p>Before each test it runs the SQL schema scripts listed by {@link #getSchemaFiles()} and then loads the
 * flat-XML data sets listed by {@link #getDataSetFiles()} using {@code CLEAN_INSERT}. For every file, the JDBC URL
 * is obtained from {@code DataBaseEnvironment.getURL} using the file's base name (file name without extension).</p>
 *
 * <p>All resource files are read from the classpath as UTF-8.</p>
 */
public abstract class AbstractDBUnitTest {

    protected static final DatabaseType CURRENT_DB_TYPE = DatabaseType.H2;

    /**
     * Data sources cached across test instances, keyed by the exact resource path passed to
     * {@link #createDataSource(String)} (schema files and data set files are therefore cached under separate keys,
     * even when they resolve to the same JDBC URL).
     */
    protected static final Map<String, DataSource> DATA_SOURCES = new HashMap<>();

    private final DataBaseEnvironment dbEnv = new DataBaseEnvironment(CURRENT_DB_TYPE);

    /**
     * Runs every schema script returned by {@link #getSchemaFiles()} against the database derived from its file name.
     *
     * <p>NOTE(review): JUnit 4 does not specify the execution order of several {@code @Before} methods declared in
     * the same class, yet {@link #importDataSet()} presumably needs the schema to exist first — confirm the
     * ordering this relies on.</p>
     *
     * @throws SQLException if a connection cannot be obtained or a script fails
     */
    @Before
    public void createSchema() throws SQLException {
        for (String each : getSchemaFiles()) {
            // Both the pooled connection and the script reader are released even if the script fails.
            try (Connection conn = createDataSource(each).getConnection();
                 Reader script = openResource(each)) {
                RunScript.execute(conn, script);
            } catch (final IOException ex) {
                throw new SQLException("Failed to close schema file: " + each, ex);
            }
        }
    }

    /**
     * Loads every data set returned by {@link #getDataSetFiles()} into its database using {@code CLEAN_INSERT}.
     *
     * @throws Exception if a data set cannot be read or the set-up operation fails
     */
    @Before
    public final void importDataSet() throws Exception {
        for (String each : getDataSetFiles()) {
            IDataSet dataSet = loadDataSet(each);
            IDatabaseTester databaseTester = new ShardingJdbcDatabaseTester(dbEnv.getDriverClassName(),
                    dbEnv.getURL(getFileName(each)), dbEnv.getUsername(), dbEnv.getPassword());
            databaseTester.setSetUpOperation(DatabaseOperation.CLEAN_INSERT);
            databaseTester.setDataSet(dataSet);
            databaseTester.onSetup();
        }
    }

    /**
     * Lists the classpath SQL scripts to run before each test.
     *
     * @return classpath resource names of the schema scripts
     */
    protected abstract List<String> getSchemaFiles();

    /**
     * Lists the classpath flat-XML data sets to load before each test.
     *
     * @return classpath resource names of the data set files
     */
    protected abstract List<String> getDataSetFiles();

    /**
     * Builds a map of data sources, one per data set file.
     *
     * @param dataSourceNamePattern {@link String#format} pattern applied to each data set file's base name to
     *                              produce the map key
     * @return data sources keyed by formatted name
     */
    protected final Map<String, DataSource> createDataSourceMap(final String dataSourceNamePattern) {
        List<String> dataSetFiles = getDataSetFiles();
        Map<String, DataSource> result = new HashMap<>(dataSetFiles.size());
        for (String each : dataSetFiles) {
            result.put(String.format(dataSourceNamePattern, getFileName(each)), createDataSource(each));
        }
        return result;
    }

    /**
     * Returns the cached data source for the given resource path, creating and caching a DBCP pool on first use.
     *
     * @param dataSetFile resource path whose base name selects the database
     * @return the data source for that resource path
     */
    private DataSource createDataSource(final String dataSetFile) {
        DataSource cached = DATA_SOURCES.get(dataSetFile);
        if (null != cached) {
            return cached;
        }
        BasicDataSource result = new BasicDataSource();
        result.setDriverClassName(dbEnv.getDriverClassName());
        result.setUrl(dbEnv.getURL(getFileName(dataSetFile)));
        result.setUsername(dbEnv.getUsername());
        result.setPassword(dbEnv.getPassword());
        result.setMaxActive(1000);
        DATA_SOURCES.put(dataSetFile, result);
        return result;
    }

    /**
     * Strips directories and the last extension from a resource path, e.g. {@code "a/b/db_0.xml"} becomes
     * {@code "db_0"}.
     *
     * @param dataSetFile resource path
     * @return the base file name without its last extension
     */
    private String getFileName(final String dataSetFile) {
        String fileName = new File(dataSetFile).getName();
        int dotIndex = fileName.lastIndexOf('.');
        return -1 == dotIndex ? fileName : fileName.substring(0, dotIndex);
    }

    /**
     * Executes a parameterized query and asserts its result equals a table from an expected flat-XML data set.
     * The supplied connection is closed when this method returns.
     *
     * @param expectedDataSetFile classpath resource of the expected data set
     * @param connection          connection to query; closed by this method
     * @param actualTableName     table name used both for the actual result and to look up the expected table
     * @param sql                 SQL with {@code ?} placeholders
     * @param params              values bound to the placeholders in order
     * @throws SQLException          if the query fails
     * @throws DatabaseUnitException if the data set cannot be read or the tables differ
     */
    protected void assertDataSet(final String expectedDataSetFile, final Connection connection,
            final String actualTableName, final String sql, final Object... params)
            throws SQLException, DatabaseUnitException {
        try (Connection conn = connection; PreparedStatement ps = conn.prepareStatement(sql)) {
            for (int i = 0; i < params.length; i++) {
                // JDBC parameter indexes are 1-based.
                ps.setObject(i + 1, params[i]);
            }
            ITable actualTable = getConnection(conn).createTable(actualTableName, ps);
            assertTable(expectedDataSetFile, actualTableName, actualTable);
        }
    }

    /**
     * Executes a query and asserts its result equals a table from an expected flat-XML data set.
     * The supplied connection is closed when this method returns.
     *
     * @param expectedDataSetFile classpath resource of the expected data set
     * @param connection          connection to query; closed by this method
     * @param actualTableName     table name used both for the actual result and to look up the expected table
     * @param sql                 SQL query without parameters
     * @throws SQLException          if the query fails
     * @throws DatabaseUnitException if the data set cannot be read or the tables differ
     */
    protected void assertDataSet(final String expectedDataSetFile, final Connection connection,
            final String actualTableName, final String sql) throws SQLException, DatabaseUnitException {
        try (Connection conn = connection) {
            ITable actualTable = getConnection(conn).createQueryTable(actualTableName, sql);
            assertTable(expectedDataSetFile, actualTableName, actualTable);
        }
    }

    /** Compares {@code actualTable} with the table named {@code tableName} in the expected data set file. */
    private void assertTable(final String expectedDataSetFile, final String tableName, final ITable actualTable)
            throws DatabaseUnitException {
        assertEquals(loadDataSet(expectedDataSetFile).getTable(tableName), actualTable);
    }

    /**
     * Wraps a JDBC connection in the DBUnit connection type matching the current database.
     *
     * @param connection raw JDBC connection
     * @return DBUnit connection using schema {@code "PUBLIC"}
     * @throws DatabaseUnitException if DBUnit rejects the connection
     * @throws UnsupportedOperationException for database types other than H2 and MySQL
     */
    private IDatabaseConnection getConnection(final Connection connection) throws DatabaseUnitException {
        switch (dbEnv.getDatabaseType()) {
            case H2:
                return new H2Connection(connection, "PUBLIC");
            case MySQL:
                // NOTE(review): "PUBLIC" is H2's default schema; for MySQL the schema is normally the database
                // name — confirm this is intended.
                return new MySqlConnection(connection, "PUBLIC");
            default:
                throw new UnsupportedOperationException(dbEnv.getDatabaseType().name());
        }
    }

    /**
     * Reads a flat-XML data set from the classpath, closing the underlying stream afterwards.
     *
     * @param dataSetFile classpath resource name
     * @return the parsed data set
     * @throws DatabaseUnitException if the file cannot be parsed or closed
     */
    private static IDataSet loadDataSet(final String dataSetFile) throws DatabaseUnitException {
        try (Reader reader = openResource(dataSetFile)) {
            return new FlatXmlDataSetBuilder().build(reader);
        } catch (final IOException ex) {
            throw new DatabaseUnitException("Failed to close data set file: " + dataSetFile, ex);
        }
    }

    /**
     * Opens a classpath resource as a UTF-8 reader; the caller must close it.
     *
     * @param resourceName classpath resource name
     * @return reader over the resource
     * @throws IllegalArgumentException if the resource is not on the classpath
     */
    private static Reader openResource(final String resourceName) {
        InputStream inputStream = AbstractDBUnitTest.class.getClassLoader().getResourceAsStream(resourceName);
        if (null == inputStream) {
            throw new IllegalArgumentException("Can not find resource on classpath: " + resourceName);
        }
        return new InputStreamReader(inputStream, StandardCharsets.UTF_8);
    }
}