org.jumpmind.symmetric.io.AbstractWriterTest.java Source code

Java tutorial

Introduction

Here is the source code for org.jumpmind.symmetric.io.AbstractWriterTest.java

Source

/**
 * Licensed to JumpMind Inc under one or more contributor
 * license agreements.  See the NOTICE file distributed
 * with this work for additional information regarding
 * copyright ownership.  JumpMind Inc licenses this file
 * to you under the GNU General Public License, version 3.0 (GPLv3)
 * (the "License"); you may not use this file except in compliance
 * with the License.
 *
 * You should have received a copy of the GNU General Public License,
 * version 3.0 (GPLv3) along with this library; if not, see
 * <http://www.gnu.org/licenses/>.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.jumpmind.symmetric.io;

import java.math.BigDecimal;
import java.text.DecimalFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.jumpmind.db.model.Table;
import org.jumpmind.db.platform.AbstractDatabasePlatform;
import org.jumpmind.db.platform.DatabaseNamesConstants;
import org.jumpmind.db.platform.IDatabasePlatform;
import org.jumpmind.db.platform.sqlite.SqliteDatabasePlatform;
import org.jumpmind.db.util.BinaryEncoding;
import org.jumpmind.symmetric.io.data.Batch;
import org.jumpmind.symmetric.io.data.Batch.BatchType;
import org.jumpmind.symmetric.io.data.writer.DataWriterStatisticConstants;
import org.jumpmind.symmetric.io.data.writer.DatabaseWriterSettings;
import org.jumpmind.symmetric.io.data.writer.DefaultDatabaseWriter;
import org.jumpmind.symmetric.io.data.writer.IgnoreBatchException;
import org.jumpmind.symmetric.io.data.CsvData;
import org.jumpmind.symmetric.io.data.DataContext;
import org.jumpmind.symmetric.io.data.IDataWriter;
import org.jumpmind.util.Statistics;
import org.junit.Assert;

/**
 * Shared support code for data writer tests. It builds source {@link Table}
 * definitions, pushes {@link CsvData} through an {@link IDataWriter} wrapped in
 * load batches, and compares the row stored in the test table with a set of
 * expected string values.
 * <p>
 * The {@link #platform} must be assigned by the concrete test before any of the
 * helpers that touch the database are called.
 */
public abstract class AbstractWriterTest {

    /** Database platform every helper in this class reads from and writes to. */
    protected static IDatabasePlatform platform;

    /**
     * When <code>true</code> (the default), exceptions thrown while writing a
     * batch are not rethrown by {@link #writeData(IDataWriter, TableCsvData...)}.
     */
    protected boolean errorExpected = true;

    protected final static String TEST_TABLE = "test_dataloader_table";

    protected final static String[] TEST_KEYS = { "id" };

    protected final static String[] TEST_COLUMNS = { "id", "string_value", "string_required_value", "char_value",
            "char_required_value", "date_value", "time_value", "boolean_value", "integer_value", "decimal_value",
            "double_value" };

    /** Last batch id handed out; shared by all test instances. */
    protected static long batchId = 10000;

    /** Last row id handed out; shared by all test instances. */
    protected static long sequenceId = 10000;

    protected DatabaseWriterSettings writerSettings = new DatabaseWriterSettings();

    /** The writer most recently passed to {@link #writeData(IDataWriter, TableCsvData...)}. */
    protected IDataWriter lastDataWriterUsed;

    /**
     * Increments the shared batch counter and returns the new value.
     *
     * @return the next batch id
     */
    protected long getNextBatchId() {
        // The counter is static, so it is guarded by the class monitor rather
        // than by the monitor of whichever test instance happens to call in.
        synchronized (AbstractWriterTest.class) {
            return ++batchId;
        }
    }

    /**
     * @return the most recently issued batch id
     */
    protected long getBatchId() {
        synchronized (AbstractWriterTest.class) {
            return batchId;
        }
    }

    /**
     * Increments the shared row id counter and returns the new value as a string.
     *
     * @return the next row id
     */
    protected String getNextId() {
        synchronized (AbstractWriterTest.class) {
            return String.valueOf(++sequenceId);
        }
    }

    /**
     * @return the most recently issued row id as a string
     */
    protected String getId() {
        synchronized (AbstractWriterTest.class) {
            return String.valueOf(sequenceId);
        }
    }

    /**
     * Builds the table definition that describes the incoming data.
     *
     * @param tableName   name of the table
     * @param keyNames    names of the key columns
     * @param columnNames names of all columns
     * @return the result of {@link Table#buildTable(String, String[], String[])}
     */
    protected Table buildSourceTable(String tableName, String[] keyNames, String[] columnNames) {
        return Table.buildTable(tableName, keyNames, columnNames);
    }

    /**
     * Writes one row described by {@link #TEST_COLUMNS} and asserts the stored
     * row matches <code>expectedValues</code>.
     *
     * @param data           the row to write
     * @param expectedValues expected column values, or <code>null</code> if no row is expected
     */
    protected void writeData(CsvData data, String[] expectedValues) {
        writeData(data, expectedValues, TEST_COLUMNS);
    }

    /**
     * @return the name of the table the assertions query; {@link #TEST_TABLE} unless overridden
     */
    protected String getTestTable() {
        return TEST_TABLE;
    }

    /**
     * Writes the given rows in a single batch without asserting anything.
     * Note that the table is built from the {@link #TEST_TABLE} constant
     * directly, not from {@link #getTestTable()}.
     *
     * @param data the rows to write
     */
    protected void writeData(CsvData... data) {
        Table table = buildSourceTable(TEST_TABLE, TEST_KEYS, TEST_COLUMNS);
        writeData(new TableCsvData(table, data));
    }

    /**
     * Writes one row to {@link #getTestTable()} keyed by {@link #TEST_KEYS} and
     * asserts the stored row matches <code>expectedValues</code>.
     *
     * @param data           the row to write
     * @param expectedValues expected column values, or <code>null</code> if no row is expected
     * @param columnNames    column names describing <code>data</code>
     */
    protected void writeData(CsvData data, String[] expectedValues, String[] columnNames) {
        writeData(data, expectedValues, getTestTable(), TEST_KEYS, columnNames);
    }

    /**
     * Writes one row and asserts the stored row matches
     * <code>expectedValues</code>. The row is looked up by the first element of
     * the parsed row data, falling back to the parsed pk data when there is no
     * row data.
     *
     * @param data           the row to write
     * @param expectedValues expected column values, or <code>null</code> if no row is expected
     * @param tableName      table the data is written to
     * @param keyNames       key column names of the source table
     * @param columnNames    column names of the source table
     */
    protected void writeData(CsvData data, String[] expectedValues, String tableName, String[] keyNames,
            String[] columnNames) {
        Table table = buildSourceTable(tableName, keyNames, columnNames);
        writeData(new TableCsvData(table, data));
        String[] pkData = data.getParsedData(CsvData.ROW_DATA);
        if (pkData == null) {
            pkData = data.getParsedData(CsvData.PK_DATA);
        }
        // Fail with a readable message instead of a bare NPE / index error.
        Assert.assertTrue("The data written has neither row data nor pk data to look the row up with.",
                pkData != null && pkData.length > 0);
        assertTestTableEquals(pkData[0], expectedValues);
    }

    /**
     * Writes each entry in its own batch using a new
     * {@link DefaultDatabaseWriter} built from {@link #platform} and
     * {@link #writerSettings}.
     *
     * @param datas the table/row groups to write
     * @return the summed statement count reported by the writer
     */
    protected long writeData(TableCsvData... datas) {
        return writeData(new DefaultDatabaseWriter(platform, writerSettings), datas);
    }

    /**
     * Opens the writer, writes every entry of <code>datas</code> in its own
     * {@link BatchType#LOAD} batch and closes the writer again.
     * <p>
     * An {@link IgnoreBatchException} ends the batch without flagging an error.
     * Any other exception ends the batch flagged as in error and is rethrown
     * only when {@link #isErrorExpected()} is <code>false</code>; checked
     * exceptions are wrapped in a {@link RuntimeException}.
     *
     * @param writer the writer to drive; also stored in {@link #lastDataWriterUsed}
     * @param datas  the table/row groups to write
     * @return the sum of {@link DataWriterStatisticConstants#STATEMENTCOUNT}
     *         over all statistics the writer reports
     */
    protected long writeData(IDataWriter writer, TableCsvData... datas) {
        this.lastDataWriterUsed = writer;
        DataContext context = new DataContext();
        writer.open(context);
        try {
            for (TableCsvData tableCsvData : datas) {
                Batch batch = new Batch(BatchType.LOAD, getNextBatchId(), "default", BinaryEncoding.BASE64, "00000",
                        "00001", false);
                try {
                    writer.start(batch);
                    if (writer.start(tableCsvData.table)) {
                        for (CsvData row : tableCsvData.data) {
                            writer.write(row);
                        }
                        writer.end(tableCsvData.table);
                    }
                    writer.end(batch, false);
                } catch (IgnoreBatchException ex) {
                    writer.end(batch, false);
                } catch (Exception ex) {
                    writer.end(batch, true);
                    if (!isErrorExpected()) {
                        if (ex instanceof RuntimeException) {
                            throw (RuntimeException) ex;
                        }
                        throw new RuntimeException(ex);
                    }
                }
            }
        } finally {
            writer.close();
        }

        long statementCount = 0;
        Collection<Statistics> stats = writer.getStatistics().values();
        for (Statistics statistics : stats) {
            statementCount += statistics.get(DataWriterStatisticConstants.STATEMENTCOUNT);
        }
        return statementCount;
    }

    /**
     * Queries {@link #getTestTable()} for the row with the given id and
     * compares it with <code>expectedValues</code>.
     * <p>
     * When <code>expectedValues</code> is not <code>null</code>, elements 1
     * through 4 (the string and char columns) are replaced <em>in place</em>
     * with their platform-specific translation before the comparison, so the
     * array must hold at least five elements.
     *
     * @param testTableId    id of the row to look up; must parse as a long
     * @param expectedValues expected values in {@link #TEST_COLUMNS} order, or
     *                       <code>null</code> if no row is expected
     */
    protected void assertTestTableEquals(String testTableId, String[] expectedValues) {
        String sql = "select " + getSelect(TEST_COLUMNS) + " from " + getTestTable() + " where "
                + getWhere(TEST_KEYS);
        Map<String, Object> results = platform.getSqlTemplate().queryForMap(sql, Long.valueOf(testTableId));

        if (expectedValues != null) {
            expectedValues[1] = translateExpectedString(expectedValues[1], false);
            expectedValues[2] = translateExpectedString(expectedValues[2], true);
            expectedValues[3] = translateExpectedCharString(expectedValues[3], 50, false);
            expectedValues[4] = translateExpectedCharString(expectedValues[4], 50, true);
        }
        assertEquals(TEST_COLUMNS, expectedValues, results);
    }

    /**
     * Builds a select list, e.g. <code>"id, string_value"</code>.
     *
     * @param columns the column names
     * @return the column names separated by <code>", "</code>
     */
    protected String getSelect(String[] columns) {
        return joinColumns(columns, "", ", ");
    }

    /**
     * Builds a where clause with one bind parameter per column, e.g.
     * <code>"id = ? and other_id = ?"</code>.
     *
     * @param columns the column names
     * @return one <code>column = ?</code> predicate per column, joined with <code>" and "</code>
     */
    protected String getWhere(String[] columns) {
        // Predicates must be AND-ed; a comma between them is not valid SQL.
        return joinColumns(columns, " = ?", " and ");
    }

    /** Appends suffix to every column and joins the results with separator. */
    private static String joinColumns(String[] columns, String suffix, String separator) {
        StringBuilder str = new StringBuilder();
        for (int i = 0; i < columns.length; i++) {
            if (i > 0) {
                str.append(separator);
            }
            str.append(columns[i]).append(suffix);
        }
        return str.toString();
    }

    /**
     * Translates an expected varchar-style value into what the current
     * platform is expected to return for it.
     *
     * @param value      the value that was written
     * @param isRequired whether the target column is a required column
     * @return {@link AbstractDatabasePlatform#REQUIRED_FIELD_NULL_SUBSTITUTE}
     *         for a required column whose value is <code>null</code> (or empty
     *         on a platform that nulls empty strings); <code>null</code> for an
     *         empty value on a platform that nulls empty strings; otherwise
     *         <code>value</code> unchanged
     */
    protected String translateExpectedString(String value, boolean isRequired) {
        boolean emptyBecomesNull = "".equals(value) && platform.getDatabaseInfo().isEmptyStringNulled();
        if (isRequired && (value == null || emptyBecomesNull)) {
            return AbstractDatabasePlatform.REQUIRED_FIELD_NULL_SUBSTITUTE;
        }
        if (emptyBecomesNull) {
            return null;
        }
        return value;
    }

    /**
     * Translates an expected char-style value into what the current platform
     * is expected to return for it, applying the null substitute for required
     * columns and the platform's padding / trimming flags.
     *
     * @param value      the value that was written
     * @param size       width the value is right-padded to when padding applies
     * @param isRequired whether the target column is a required column
     * @return the translated value, possibly <code>null</code>
     */
    protected String translateExpectedCharString(String value, int size, boolean isRequired) {
        if (isRequired && value == null) {
            if (!platform.getDatabaseInfo().isRequiredCharColumnEmptyStringSameAsNull()
                    || platform.getDatabaseInfo().isEmptyStringNulled()) {
                value = AbstractDatabasePlatform.REQUIRED_FIELD_NULL_SUBSTITUTE;
            }
        }
        if (value == null) {
            return null;
        }

        boolean padded;
        if (StringUtils.isBlank(value)) {
            padded = platform.getDatabaseInfo().isBlankCharColumnSpacePadded();
        } else {
            padded = platform.getDatabaseInfo().isNonBlankCharColumnSpacePadded();
        }
        if (padded) {
            return StringUtils.rightPad(value, size);
        }
        if (platform.getDatabaseInfo().isCharColumnSpaceTrimmed()) {
            // Strip trailing spaces only.
            return value.replaceFirst(" *$", "");
        }
        return value;
    }

    /**
     * Compares a queried row with the expected string values, column by column.
     * <p>
     * Result objects are converted to strings before the comparison: doubles
     * and big decimals are formatted with at least two fraction digits when
     * the expected value contains the default locale's decimal separator,
     * dates are formatted as <code>yyyy-MM-dd HH:mm:ss.000</code>, booleans
     * become <code>"1"</code> or <code>"0"</code>, and everything else uses
     * <code>toString()</code>. On SQLite an expected value compared against a
     * double result is first normalised through {@link Double#toString()}.
     * The <code>expected</code> array itself is not modified.
     *
     * @param name     column names, index-aligned with <code>expected</code>
     * @param expected expected values, or <code>null</code> to assert that
     *                 <code>results</code> is <code>null</code>
     * @param results  the queried row keyed by column name
     */
    protected void assertEquals(String[] name, String[] expected, Map<String, Object> results) {
        if (expected == null) {
            Assert.assertNull("Expected empty results. " + printDatabase(), results);
            return;
        }

        Assert.assertNotNull(String.format("Did not find the expected row: %s.", Arrays.toString(expected)),
                results);

        // Loop invariant, so it is looked up once.
        char decimal = ((DecimalFormat) DecimalFormat.getInstance()).getDecimalFormatSymbols()
                .getDecimalSeparator();

        for (int i = 0; i < expected.length; i++) {
            Object resultObj = results.get(name[i]);
            String expectedValue = expected[i];
            String resultValue = null;

            if ((resultObj instanceof Double || resultObj instanceof BigDecimal) && expectedValue != null
                    && expectedValue.indexOf(decimal) != -1) {
                DecimalFormat df = new DecimalFormat("0.00####################################");
                resultValue = df.format(resultObj);
            } else if (resultObj instanceof Date) {
                SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.000");
                resultValue = df.format(resultObj);
            } else if (resultObj instanceof Boolean) {
                resultValue = ((Boolean) resultObj) ? "1" : "0";
            } else if (resultObj instanceof Double) {
                resultValue = resultObj.toString();
                // A null expected value is left alone so the assertion below
                // reports the mismatch instead of an NPE being thrown here.
                if (expectedValue != null && platform instanceof SqliteDatabasePlatform) {
                    expectedValue = Double.valueOf(expectedValue).toString();
                }
            } else if (resultObj != null) {
                resultValue = resultObj.toString();
            }

            Assert.assertEquals(name[i] + ". " + printDatabase(), expectedValue, resultValue);
        }
    }

    /**
     * @return a sentence naming the platform under test, for use in assertion messages
     */
    protected String printDatabase() {
        return " The database we are testing against is " + platform.getName() + ".";
    }

    /**
     * @return <code>true</code> if the platform name equals {@link DatabaseNamesConstants#ORACLE}
     */
    protected boolean isOracle() {
        return DatabaseNamesConstants.ORACLE.equals(platform.getName());
    }

    /**
     * @param errorExpected whether exceptions raised while writing a batch should be suppressed
     */
    public void setErrorExpected(boolean errorExpected) {
        this.errorExpected = errorExpected;
    }

    /**
     * @return whether exceptions raised while writing a batch are suppressed
     */
    public boolean isErrorExpected() {
        return errorExpected;
    }

    /**
     * Selects all columns of the row with the given id from
     * {@link #TEST_TABLE} (the constant, not {@link #getTestTable()}).
     *
     * @param id id of the row; must parse as an int
     * @return the result of the platform's <code>queryForMap</code> call
     */
    public Map<String, Object> queryForRow(String id) {
        return platform.getSqlTemplate().queryForMap("select * from " + TEST_TABLE + " where id=?",
                Integer.valueOf(id));
    }

    /**
     * Pairs a table definition with the rows to write to it.
     */
    protected class TableCsvData {
        Table table;
        List<CsvData> data;

        /**
         * @param table    the table the rows belong to
         * @param csvDatas the rows; copied into a new list
         */
        public TableCsvData(Table table, CsvData... csvDatas) {
            this.table = table;
            this.data = new ArrayList<CsvData>(Arrays.asList(csvDatas));
        }

        /**
         * @param table the table the rows belong to
         * @param data  the rows; the list is stored as given, not copied
         */
        public TableCsvData(Table table, List<CsvData> data) {
            this.table = table;
            this.data = data;
        }

    }

    /**
     * Runs <code>select count(*)</code> against the given table.
     *
     * @param tableName name of the table to count
     * @return the number of rows reported by the query
     */
    protected long countRows(String tableName) {
        return platform.getSqlTemplate().queryForInt(String.format("select count(*) from %s", tableName));
    }

}