// Java tutorial
/*
 * Copyright 2014-2016 the original author or authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.openddal.test;

import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertTrue;
import static junit.framework.Assert.fail;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.Reader;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.nio.channels.FileChannel;
import java.nio.channels.FileLock;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.SimpleTimeZone;
import java.util.UUID;

import javax.sql.DataSource;

import org.apache.commons.dbcp.BasicDataSource;
import org.junit.After;

import com.openddal.message.DbException;
import com.openddal.test.utils.ProxyCodeGenerator;
import com.openddal.test.utils.ResultVerifier;
import com.openddal.util.FilePath;
import com.openddal.util.JdbcUtils;
import com.openddal.util.MurmurHash;

import junit.framework.Assert;

/**
 * Common base class for the test cases. It owns the {@link DataSource} the
 * tests run against and offers logging, timing, memory and assertion helpers.
 */
public abstract class BaseTestCase {

    /**
     * The base directory.
     */
    public static final String BASE_TEST_DIR = "./data";

    /**
     * The configuration location that is published through the
     * {@code ddal.engineConfigLocation} system property when this class is
     * loaded.
     */
    public static final String BASE_TEST_CONFIGLOCATION = "/config/ddal-config.xml";

    static {
        System.setProperty("ddal.engineConfigLocation", BASE_TEST_CONFIGLOCATION);
    }

    /**
     * An id used to create unique file names.
     */
    protected static int uniqueId;

    /**
     * The last time something was printed.
     */
    private static long lastPrint;

    /**
     * Blocks allocated by {@link #eatMemory(int)}; kept here so that they stay
     * strongly referenced until {@link #freeMemory()} is called.
     */
    private final LinkedList<byte[]> memory = new LinkedList<byte[]>();

    /**
     * The time when the test was started.
     */
    protected long start;

    /**
     * The data source all connections of this test are taken from.
     */
    protected DataSource dataSource;

    /**
     * The JDBC URL used by the default data source.
     */
    protected static String url = "jdbc:openddal:config/ddal-config.xml;";

    /**
     * The JDBC driver class used by the default data source.
     */
    protected static String driverClassName = "com.openddal.jdbc.Driver";

    /**
     * Create a test case backed by a new pooled data source that is built from
     * {@link #driverClassName} and {@link #url}.
     */
    public BaseTestCase() {
        this.dataSource = newDataSource();
    }

    /**
     * Create a test case that uses the given data source.
     *
     * @param dataSource the data source to use
     */
    public BaseTestCase(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Build the default data source (READ_COMMITTED isolation).
     *
     * @return the new data source
     */
    private BasicDataSource newDataSource() {
        BasicDataSource ds = new BasicDataSource();
        ds.setDriverClassName(driverClassName);
        ds.setUrl(url);
        ds.setDefaultTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED);
        return ds;
    }

    /**
     * @param dataSource the dataSource to set
     */
    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Get the number of megabytes heap memory in use.
     *
     * @return the used megabytes
     */
    public static int getMemoryUsed() {
        return (int) (getMemoryUsedBytes() / 1024 / 1024);
    }

    /**
     * Get the number of bytes heap memory in use. The garbage collector is
     * requested up to eight times; sampling stops as soon as the used memory
     * no longer shrinks.
     *
     * @return the used bytes
     */
    public static long getMemoryUsedBytes() {
        Runtime rt = Runtime.getRuntime();
        long memory = Long.MAX_VALUE;
        for (int i = 0; i < 8; i++) {
            rt.gc();
            long memNow = rt.totalMemory() - rt.freeMemory();
            if (memNow >= memory) {
                break;
            }
            memory = memNow;
        }
        return memory;
    }

    /**
     * Log an error message. The message and stack trace are written to
     * {@code System.err} and the stack trace is appended to the file
     * {@code error.txt}, guarded by a file lock on {@code error.lock}.
     *
     * @param s the message
     * @param e the exception, or null to log a new exception built from the
     *            message
     */
    public static void logError(String s, Throwable e) {
        if (e == null) {
            e = new Exception(s);
        }
        System.out.flush();
        System.err.println("ERROR: " + s + " " + e.toString() + " ------------------------------");
        e.printStackTrace();
        // synchronize on this class, because file locks are only visible to
        // other JVMs
        synchronized (BaseTestCase.class) {
            FileChannel fc = null;
            FileLock lock = null;
            try {
                // lock
                fc = FilePath.get("error.lock").open("rw");
                while ((lock = fc.tryLock()) == null) {
                    Thread.sleep(10);
                }
                // append
                PrintWriter pw = new PrintWriter(new FileWriter("error.txt", true));
                try {
                    e.printStackTrace(pw);
                } finally {
                    // closing the PrintWriter also closes the FileWriter
                    pw.close();
                }
            } catch (InterruptedException interrupted) {
                // keep the interrupt visible to the caller
                Thread.currentThread().interrupt();
                interrupted.printStackTrace();
            } catch (Throwable t) {
                // logging must never make the test itself fail
                t.printStackTrace();
            } finally {
                // unlock and release the channel even if appending failed
                if (lock != null) {
                    try {
                        lock.release();
                    } catch (IOException ignored) {
                        // best effort; closing the channel drops the lock too
                    }
                }
                if (fc != null) {
                    try {
                        fc.close();
                    } catch (IOException ignored) {
                        // nothing sensible can be done here
                    }
                }
            }
        }
        System.err.flush();
    }

    /**
     * Print a message, prepended with the current wall clock time and the
     * specified time in milliseconds.
     *
     * @param millis the time in milliseconds
     * @param s the message
     */
    static synchronized void printlnWithTime(long millis, String s) {
        SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss");
        s = dateFormat.format(new java.util.Date()) + " " + formatTime(millis) + " " + s;
        System.out.println(s);
    }

    /**
     * Format the time in the format hh:mm:ss.123 where 123 is milliseconds.
     * A leading "00:" hour part is removed.
     *
     * @param millis the time in milliseconds
     * @return the formatted time
     */
    static String formatTime(long millis) {
        String s = new java.sql.Time(java.sql.Time.valueOf("0:0:0").getTime() + millis).toString() + "."
                + ("" + (1000 + (millis % 1000))).substring(1);
        if (s.startsWith("00:")) {
            s = s.substring(3);
        }
        return s;
    }

    /**
     * Compare the first {@code len} entries of two rows; null only equals
     * null.
     *
     * @param a the first row
     * @param b the second row
     * @param len the number of columns to compare
     * @return true if all compared columns are equal
     */
    private static boolean testRow(String[] a, String[] b, int len) {
        for (int i = 0; i < len; i++) {
            String sa = a[i];
            String sb = b[i];
            if (sa == null || sb == null) {
                if (sa != sb) {
                    return false;
                }
            } else if (!sa.equals(sb)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Read the current row of a result set as strings.
     *
     * @param rs the result set, positioned on a row
     * @param len the number of columns to read
     * @return the column values
     */
    private static String[] getData(ResultSet rs, int len) throws SQLException {
        String[] data = new String[len];
        for (int i = 0; i < len; i++) {
            data[i] = rs.getString(i + 1);
            // just check if it works
            rs.getObject(i + 1);
        }
        return data;
    }

    /**
     * Format a row as {{a}{b}...} for failure messages.
     *
     * @param row the row
     * @return the formatted row
     */
    private static String formatRow(String[] row) {
        StringBuilder sb = new StringBuilder("{");
        for (String r : row) {
            sb.append('{').append(r).append('}');
        }
        return sb.append('}').toString();
    }

    /**
     * Drop everything in front of the first "+/-" marker of a script line, if
     * the marker is present.
     *
     * @param scriptLine the script line
     * @return the line starting at the marker, or the unchanged line
     */
    private static String removeRowCount(String scriptLine) {
        int index = scriptLine.indexOf("+/-");
        if (index >= 0) {
            scriptLine = scriptLine.substring(index);
        }
        return scriptLine;
    }

    /**
     * Construct a stream of 20 KB that fails while reading with the provided
     * exception once more than 10 KB have been consumed.
     *
     * @param e the exception
     * @return the stream
     */
    public static ByteArrayInputStream createFailingStream(final Exception e) {
        return new ByteArrayInputStream(new byte[20 * 1024]) {
            @Override
            public int read(byte[] buffer, int off, int len) {
                if (this.pos > 10 * 1024) {
                    throwException(e);
                }
                return super.read(buffer, off, len);
            }
        };
    }

    /**
     * Throw a checked exception, without having to declare the method as
     * throwing a checked exception.
     *
     * @param e the exception to throw
     */
    public static void throwException(Throwable e) {
        BaseTestCase.<RuntimeException>throwThis(e);
    }

    /**
     * Rethrow the given throwable; the unchecked generic cast hides its
     * checked type from the compiler.
     */
    @SuppressWarnings("unchecked")
    private static <E extends Throwable> void throwThis(Throwable e) throws E {
        throw (E) e;
    }

    /**
     * Hook that runs after each test. There is currently nothing to release.
     */
    @After
    public void destory() {
        // intentionally empty
    }

    /**
     * Close the given JDBC resources in the order result set, statement,
     * connection. Null arguments are skipped and a failure to close one
     * resource is printed and does not prevent closing the others.
     *
     * @param connection the connection, or null
     * @param statement the statement, or null
     * @param rs the result set, or null
     */
    public void close(Connection connection, Statement statement, ResultSet rs) {
        if (rs != null) {
            try {
                if (!rs.isClosed()) {
                    rs.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if (statement != null) {
            try {
                if (!statement.isClosed()) {
                    statement.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        if (connection != null) {
            try {
                if (!connection.isClosed()) {
                    connection.close();
                }
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Create a non-negative id by hashing a random UUID.
     *
     * @return the id, never negative
     */
    public long getUUID() {
        UUID uuid = UUID.randomUUID();
        long hash = MurmurHash.hash64(uuid.toString());
        // Math.abs(Long.MIN_VALUE) is still negative, so map it explicitly
        return hash == Long.MIN_VALUE ? 0L : Math.abs(hash);
    }

    /**
     * Fetch the next value of the sequence {@code order_seq}.
     *
     * @param conn the connection to run the query on; it is left open. If
     *            null, a connection is borrowed from the data source and
     *            closed again before returning.
     * @return the next sequence value
     * @throws SQLException if the query fails or returns no row
     */
    public int nextOrderSeqVaule(Connection conn) throws SQLException {
        boolean borrowed = conn == null;
        Connection target = borrowed ? this.dataSource.getConnection() : conn;
        PreparedStatement ptmt = null;
        ResultSet rs = null;
        try {
            ptmt = target.prepareStatement("SELECT nextval('order_seq')");
            rs = ptmt.executeQuery();
            if (!rs.next()) {
                throw new SQLException("nextval('order_seq') returned no row");
            }
            return rs.getInt(1);
        } finally {
            // only close the connection when it was opened here
            close(borrowed ? target : null, ptmt, rs);
        }
    }

    /**
     * Get a connection from the data source of this test.
     *
     * @return the connection
     */
    public Connection getConnection() throws SQLException {
        return dataSource.getConnection();
    }

    /**
     * Write a value to system out; see {@link #trace(String)}.
     *
     * @param x the value to write
     */
    protected void trace(int x) {
        trace("" + x);
    }

    /**
     * Write a message to system out. Resets the print throttle first, so the
     * message is not suppressed by {@link #println(String)}.
     *
     * @param s the message to write
     */
    public void trace(String s) {
        lastPrint = 0;
        println(s);
    }

    /**
     * Print how much memory is currently used.
     */
    protected void traceMemory() {
        trace("mem=" + getMemoryUsed());
    }

    /**
     * Print the currently used memory, the message and the given time in
     * milliseconds.
     *
     * @param s the message
     * @param time the time in millis
     */
    public void printTimeMemory(String s, long time) {
        println(getMemoryUsed() + " MB: " + s + " ms: " + time);
    }

    /**
     * Print a message to system out. Messages arriving within one second of
     * the previously printed one are dropped.
     *
     * @param s the message
     */
    public void println(String s) {
        long now = System.currentTimeMillis();
        if (now > lastPrint + 1000) {
            lastPrint = now;
            long time = now - start;
            printlnWithTime(time, getClass().getName() + " " + s);
        }
    }

    /**
     * Print the current time and a message to system out.
     *
     * @param s the message
     */
    protected void printTime(String s) {
        SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss");
        println(dateFormat.format(new java.util.Date()) + " " + s);
    }

    /**
     * Check if two values are equal, and if not throw an exception.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @throws AssertionError if the values are not equal
     */
    public void assertEquals(byte[] expected, byte[] actual) {
        if (expected == null || actual == null) {
            assertTrue(expected == actual);
            return;
        }
        assertEquals(expected.length, actual.length);
        for (int i = 0; i < expected.length; i++) {
            if (expected[i] != actual[i]) {
                fail("[" + i + "]: expected: " + (int) expected[i] + " actual: " + (int) actual[i]);
            }
        }
    }

    /**
     * Check if two values are equal, and if not throw an exception. Two nulls
     * are equal; a null and a non-null value are not.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @throws AssertionError if the values are not equal
     */
    public void assertEquals(java.util.Date expected, java.util.Date actual) {
        if (expected == null || actual == null) {
            if (expected != actual) {
                fail("Expected: " + expected + " actual: " + actual);
            }
            return;
        }
        if (!expected.equals(actual)) {
            DateFormat df = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS");
            SimpleTimeZone gmt = new SimpleTimeZone(0, "Z");
            df.setTimeZone(gmt);
            fail("Expected: " + df.format(expected) + " actual: " + df.format(actual));
        }
    }

    /**
     * Check if two values are equal, and if not throw an exception.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @throws AssertionError if the values are not equal
     */
    public void assertEquals(Object[] expected, Object[] actual) {
        if (expected == null || actual == null) {
            assertTrue(expected == actual);
            return;
        }
        assertEquals(expected.length, actual.length);
        for (int i = 0; i < expected.length; i++) {
            if (expected[i] == null || actual[i] == null) {
                if (expected[i] != actual[i]) {
                    fail("[" + i + "]: expected: " + expected[i] + " actual: " + actual[i]);
                }
            } else if (!expected[i].equals(actual[i])) {
                fail("[" + i + "]: expected: " + expected[i] + " actual: " + actual[i]);
            }
        }
    }

    /**
     * Check if two readers are equal, and if not throw an exception. The
     * readers are compared character by character and closed afterwards.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @param len the maximum length, or -1
     * @throws AssertionError if the values are not equal
     */
    protected void assertEqualReaders(Reader expected, Reader actual, int len) throws IOException {
        for (int i = 0; len < 0 || i < len; i++) {
            int ce = expected.read();
            int ca = actual.read();
            // compare the characters read, not the reader objects themselves
            assertEquals(ce, ca);
            if (ce == -1) {
                break;
            }
        }
        expected.close();
        actual.close();
    }

    /**
     * Check if two streams are equal, and if not throw an exception. Zero
     * length reads are interleaved to verify that they are harmless.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @param len the maximum length, or -1
     * @throws AssertionError if the values are not equal
     */
    protected void assertEqualStreams(InputStream expected, InputStream actual, int len) throws IOException {
        // this doesn't actually read anything - just tests reading 0 bytes
        actual.read(new byte[0]);
        expected.read(new byte[0]);
        actual.read(new byte[10], 3, 0);
        expected.read(new byte[10], 0, 0);

        for (int i = 0; len < 0 || i < len; i++) {
            int ca = actual.read();
            actual.read(new byte[0]);
            int ce = expected.read();
            if (ca != ce) {
                assertEquals(ce, ca);
            }
            if (ca == -1) {
                break;
            }
        }
        actual.read(new byte[10], 3, 0);
        expected.read(new byte[10], 0, 0);
        actual.read(new byte[0]);
        expected.read(new byte[0]);
        actual.close();
        expected.close();
    }

    /**
     * Check if two result sets are equal, and if not throw an exception. The
     * column counts and the string values of all rows are compared.
     *
     * @param message the message to use if the check fails
     * @param rs0 the first result set
     * @param rs1 the second result set
     * @throws AssertionError if the values are not equal
     */
    protected void assertEquals(String message, ResultSet rs0, ResultSet rs1) throws SQLException {
        ResultSetMetaData meta = rs0.getMetaData();
        int columns = meta.getColumnCount();
        assertEquals(columns, rs1.getMetaData().getColumnCount());
        while (rs0.next()) {
            assertTrue(message, rs1.next());
            for (int i = 0; i < columns; i++) {
                Assert.assertEquals(rs0.getString(i + 1), rs1.getString(i + 1));
            }
        }
        assertFalse(message, rs0.next());
        assertFalse(message, rs1.next());
    }

    /**
     * Check that the first value is strictly smaller than the second value,
     * and if not throw an exception.
     *
     * @param a the first value
     * @param b the second value (must be larger than the first value)
     * @throws AssertionError if the first value is larger or equal
     */
    protected void assertSmaller(long a, long b) {
        if (a >= b) {
            fail("a: " + a + " is not smaller than b: " + b);
        }
    }

    /**
     * Check that a result contains the given substring.
     *
     * @param result the result value
     * @param contains the term that should appear in the result
     * @throws AssertionError if the term was not found
     */
    protected void assertContains(String result, String contains) {
        if (result.indexOf(contains) < 0) {
            fail(result + " does not contain: " + contains);
        }
    }

    /**
     * Check that a text starts with the expected characters.
     *
     * @param text the text
     * @param expectedStart the expected prefix
     * @throws AssertionError if the text does not start with the expected
     *             characters
     */
    protected void assertStartsWith(String text, String expectedStart) {
        if (!text.startsWith(expectedStart)) {
            fail("[" + text + "] does not start with: [" + expectedStart + "]");
        }
    }

    /**
     * Check if two values are equal, and if not throw an exception. Two NaN
     * values are considered equal.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @throws AssertionError if the values are not equal
     */
    protected void assertEquals(double expected, double actual) {
        if (expected != actual) {
            if (Double.isNaN(expected) && Double.isNaN(actual)) {
                // if both a NaN, then there is no error
            } else {
                fail("Expected: " + expected + " actual: " + actual);
            }
        }
    }

    /**
     * Check if two values are equal, and if not throw an exception. Two NaN
     * values are considered equal.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @throws AssertionError if the values are not equal
     */
    protected void assertEquals(float expected, float actual) {
        if (expected != actual) {
            if (Float.isNaN(expected) && Float.isNaN(actual)) {
                // if both a NaN, then there is no error
            } else {
                fail("Expected: " + expected + " actual: " + actual);
            }
        }
    }

    /**
     * Check that the result set row count matches. The result set is read to
     * its end.
     *
     * @param expected the number of expected rows
     * @param rs the result set
     * @throws AssertionError if a different number of rows have been found
     */
    protected void assertResultRowCount(int expected, ResultSet rs) throws SQLException {
        int i = 0;
        while (rs.next()) {
            i++;
        }
        assertEquals(expected, i);
    }

    /**
     * Check that the result set of a query is exactly this value.
     *
     * @param stat the statement
     * @param sql the SQL statement to execute
     * @param expected the expected result value
     * @throws AssertionError if a different result value was returned
     */
    protected void assertSingleValue(Statement stat, String sql, int expected) throws SQLException {
        ResultSet rs = stat.executeQuery(sql);
        assertTrue(rs.next());
        assertEquals(expected, rs.getInt(1));
        assertFalse(rs.next());
    }

    /**
     * Check that the first column of the first row of a query is exactly this
     * value. An empty result is treated like a null value.
     *
     * @param expected the expected result value
     * @param stat the statement
     * @param sql the SQL statement to execute
     * @throws AssertionError if a different result value was returned
     */
    protected void assertResult(String expected, Statement stat, String sql) throws SQLException {
        ResultSet rs = stat.executeQuery(sql);
        if (rs.next()) {
            String actual = rs.getString(1);
            Assert.assertEquals(expected, actual);
        } else {
            Assert.assertEquals(expected, null);
        }
    }

    /**
     * Check that executing the specified query results in the specified error.
     *
     * @param expectedErrorCode the expected error code
     * @param stat the statement
     * @param sql the SQL statement to execute
     */
    protected void assertThrows(int expectedErrorCode, Statement stat, String sql) {
        try {
            stat.execute(sql);
            fail("Expected error: " + expectedErrorCode);
        } catch (SQLException ex) {
            assertEquals(expectedErrorCode, ex.getErrorCode());
        }
    }

    /**
     * Check if the result set meta data is correct. Each of the array
     * arguments may be null, in which case that aspect is not checked.
     *
     * @param rs the result set
     * @param columnCount the expected column count
     * @param labels the expected column labels
     * @param datatypes the expected data types
     * @param precision the expected precisions
     * @param scale the expected scales
     */
    protected void assertResultSetMeta(ResultSet rs, int columnCount, String[] labels, int[] datatypes,
            int[] precision, int[] scale) throws SQLException {
        ResultSetMetaData meta = rs.getMetaData();
        int cc = meta.getColumnCount();
        if (cc != columnCount) {
            fail("result set contains " + cc + " columns not " + columnCount);
        }
        for (int i = 0; i < columnCount; i++) {
            if (labels != null) {
                String l = meta.getColumnLabel(i + 1);
                if (!labels[i].equals(l)) {
                    fail("column label " + i + " is " + l + " not " + labels[i]);
                }
            }
            if (datatypes != null) {
                int t = meta.getColumnType(i + 1);
                if (datatypes[i] != t) {
                    fail("column datatype " + i + " is " + t + " not " + datatypes[i] + " (prec="
                            + meta.getPrecision(i + 1) + " scale=" + meta.getScale(i + 1) + ")");
                }
                // for the well known types also verify type name and class
                String typeName = meta.getColumnTypeName(i + 1);
                String className = meta.getColumnClassName(i + 1);
                switch (t) {
                case Types.INTEGER:
                    Assert.assertEquals("INTEGER", typeName);
                    Assert.assertEquals("java.lang.Integer", className);
                    break;
                case Types.VARCHAR:
                    Assert.assertEquals("VARCHAR", typeName);
                    Assert.assertEquals("java.lang.String", className);
                    break;
                case Types.SMALLINT:
                    Assert.assertEquals("SMALLINT", typeName);
                    Assert.assertEquals("java.lang.Short", className);
                    break;
                case Types.TIMESTAMP:
                    Assert.assertEquals("TIMESTAMP", typeName);
                    Assert.assertEquals("java.sql.Timestamp", className);
                    break;
                case Types.DECIMAL:
                    Assert.assertEquals("DECIMAL", typeName);
                    Assert.assertEquals("java.math.BigDecimal", className);
                    break;
                default:
                }
            }
            if (precision != null) {
                int p = meta.getPrecision(i + 1);
                if (precision[i] != p) {
                    fail("column precision " + i + " is " + p + " not " + precision[i]);
                }
            }
            if (scale != null) {
                int s = meta.getScale(i + 1);
                if (scale[i] != s) {
                    fail("column scale " + i + " is " + s + " not " + scale[i]);
                }
            }
        }
    }

    /**
     * Check if a result set contains the expected data. The sort order is
     * significant
     *
     * @param rs the result set
     * @param data the expected data
     * @throws AssertionError if there is a mismatch
     */
    protected void assertResultSetOrdered(ResultSet rs, String[][] data) throws SQLException {
        assertResultSet(true, rs, data);
    }

    /**
     * Check if a result set contains the expected data.
     *
     * @param ordered if the sort order is significant
     * @param rs the result set
     * @param data the expected data
     * @throws AssertionError if there is a mismatch
     */
    private void assertResultSet(boolean ordered, ResultSet rs, String[][] data) throws SQLException {
        int len = rs.getMetaData().getColumnCount();
        int rows = data.length;
        if (rows == 0) {
            // special case: no rows - there is no data[0] to inspect
            if (rs.next()) {
                fail("testResultSet expected rowCount:" + rows + " got:>=1");
            }
            return;
        }
        int len2 = data[0].length;
        if (len < len2) {
            fail("testResultSet expected columnCount:" + len2 + " got:" + len);
        }
        for (int i = 0; i < rows; i++) {
            if (!rs.next()) {
                fail("testResultSet expected rowCount:" + rows + " got:" + i);
            }
            String[] row = getData(rs, len);
            if (ordered) {
                String[] good = data[i];
                if (!testRow(good, row, good.length)) {
                    fail("testResultSet row not equal, got:\n" + formatRow(row) + "\n" + formatRow(good));
                }
            } else {
                // the row may match any of the expected rows
                boolean found = false;
                for (int j = 0; j < rows; j++) {
                    String[] good = data[j];
                    if (testRow(good, row, good.length)) {
                        found = true;
                        break;
                    }
                }
                if (!found) {
                    fail("testResultSet no match for row:" + formatRow(row));
                }
            }
        }
        if (rs.next()) {
            String[] row = getData(rs, len);
            fail("testResultSet expected rowcount:" + rows + " got:>=" + (rows + 1) + " data:" + formatRow(row));
        }
    }

    /**
     * Verify that the connection no longer accepts statements and close it.
     * Two statements are executed that are expected to fail with an
     * SQLException; the test fails if they succeed. Errors while closing the
     * connection are ignored.
     *
     * @param conn the database connection
     */
    protected void crash(Connection conn) {
        try {
            conn.createStatement().execute("SET WRITE_DELAY 0");
            conn.createStatement().execute("CREATE TABLE TEST_A(ID INT)");
            fail("should be crashed already");
        } catch (SQLException e) {
            // expected
        }
        try {
            conn.close();
        } catch (SQLException e) {
            // ignore
        }
    }

    /**
     * Read a string from the reader. This method reads until end of file.
     *
     * @param reader the reader
     * @return the string read, or null if the reader is null
     * @throws AssertionError if reading fails
     */
    protected String readString(Reader reader) {
        if (reader == null) {
            return null;
        }
        StringBuilder buffer = new StringBuilder();
        try {
            while (true) {
                int c = reader.read();
                if (c == -1) {
                    break;
                }
                buffer.append((char) c);
            }
            return buffer.toString();
        } catch (Exception e) {
            // report why reading failed instead of a bare assertion failure
            fail("Could not read from reader: " + e);
            return null;
        }
    }

    /**
     * Check that a given exception is not an unexpected 'general error'
     * exception.
     *
     * @param e the error
     */
    public void assertKnownException(SQLException e) {
        assertKnownException("", e);
    }

    /**
     * Check that a given exception is not an unexpected 'general error'
     * exception. A general error (SQL state starting with HY000) is logged
     * with {@link #logError(String, Throwable)}.
     *
     * @param message the message
     * @param e the exception
     */
    protected void assertKnownException(String message, SQLException e) {
        // the SQL state of an SQLException may be null
        String state = e == null ? null : e.getSQLState();
        if (state != null && state.startsWith("HY000")) {
            BaseTestCase.logError("Unexpected General error " + message, e);
        }
    }

    /**
     * Check if two values are equal, and if not throw an exception.
     *
     * @param expected the expected value
     * @param actual the actual value
     * @throws AssertionError if the values are not equal
     */
    protected void assertEquals(Integer expected, Integer actual) {
        if (expected == null || actual == null) {
            if (expected != actual) {
                Assert.assertEquals("" + expected, "" + actual);
            }
        } else {
            assertEquals(expected.intValue(), actual.intValue());
        }
    }

    /**
     * Check if two databases contain the same meta data, by comparing the
     * output of the SCRIPT statement line by line (ignoring line order).
     *
     * @param stat1 the connection to the first database
     * @param stat2 the connection to the second database
     * @throws AssertionError if the databases don't match
     */
    protected void assertEqualDatabases(Statement stat1, Statement stat2) throws SQLException {
        ResultSet rs = stat1
                .executeQuery("select value from information_schema.settings " + "where name='ANALYZE_AUTO'");
        int analyzeAuto = rs.next() ? rs.getInt(1) : 0;
        if (analyzeAuto > 0) {
            stat1.execute("analyze");
            stat2.execute("analyze");
        }
        ResultSet rs1 = stat1.executeQuery("SCRIPT simple NOPASSWORDS");
        ResultSet rs2 = stat2.executeQuery("SCRIPT simple NOPASSWORDS");
        ArrayList<String> list1 = new ArrayList<String>();
        ArrayList<String> list2 = new ArrayList<String>();
        while (rs1.next()) {
            String s1 = rs1.getString(1);
            s1 = removeRowCount(s1);
            if (!rs2.next()) {
                fail("expected: " + s1);
            }
            String s2 = rs2.getString(1);
            s2 = removeRowCount(s2);
            if (!s1.equals(s2)) {
                // remember the differing lines; they may only be reordered
                list1.add(s1);
                list2.add(s2);
            }
        }
        for (String s : list1) {
            if (!list2.remove(s)) {
                fail("only found in first: " + s + " remaining: " + list2);
            }
        }
        Assert.assertEquals("remaining: " + list2, 0, list2.size());
        assertFalse(rs2.next());
    }

    /**
     * Get the classpath list used to execute java -cp ...
     *
     * @return the classpath list
     */
    protected String getClassPath() {
        return "bin" + File.pathSeparator + "temp" + File.pathSeparator + ".";
    }

    /**
     * Use up almost all memory. The allocated blocks stay referenced until
     * {@link #freeMemory()} is called.
     *
     * @param remainingKB the number of kilobytes that are not referenced
     */
    protected void eatMemory(int remainingKB) {
        byte[] reserve = new byte[remainingKB * 1024];
        // first, eat memory in 16 KB blocks, then eat in 16 byte blocks
        for (int size = 16 * 1024; size > 0; size /= 1024) {
            while (true) {
                try {
                    byte[] block = new byte[size];
                    memory.add(block);
                } catch (OutOfMemoryError e) {
                    break;
                }
            }
        }
        // silly code - makes sure there are no warnings
        reserve[0] = reserve[1];
    }

    /**
     * Remove the hard reference to the memory.
     */
    protected void freeMemory() {
        memory.clear();
    }

    /**
     * Verify the next method call on the object will throw an exception.
     *
     * @param <T> the class of the object
     * @param expectedExceptionClass the expected exception class to be thrown
     * @param obj the object to wrap
     * @return a proxy for the object
     */
    protected <T> T assertThrows(final Class<?> expectedExceptionClass, final T obj) {
        return assertThrows(new ResultVerifier() {
            @Override
            public boolean verify(Object returnValue, Throwable t, Method m, Object... args) {
                if (t == null) {
                    throw new AssertionError("Expected an exception of type "
                            + expectedExceptionClass.getSimpleName() + " to be thrown, but the method returned "
                            + returnValue + " for " + ProxyCodeGenerator.formatMethodCall(m, args));
                }
                if (!expectedExceptionClass.isAssignableFrom(t.getClass())) {
                    AssertionError ae = new AssertionError("Expected an exception of type\n"
                            + expectedExceptionClass.getSimpleName() + " to be thrown, but the method under test "
                            + "threw an exception of type\n" + t.getClass().getSimpleName()
                            + " (see in the 'Caused by' for the exception " + "that was thrown) " + " for "
                            + ProxyCodeGenerator.formatMethodCall(m, args));
                    ae.initCause(t);
                    throw ae;
                }
                return false;
            }
        }, obj);
    }

    /**
     * Verify the next method call on the object will throw an exception.
     *
     * @param <T> the class of the object
     * @param expectedErrorCode the expected error code
     * @param obj the object to wrap
     * @return a proxy for the object
     */
    protected <T> T assertThrows(final int expectedErrorCode, final T obj) {
        return assertThrows(new ResultVerifier() {
            @Override
            public boolean verify(Object returnValue, Throwable t, Method m, Object... args) {
                // anything that is not a DbException / SQLException counts as
                // error code 0
                int errorCode;
                if (t instanceof DbException) {
                    errorCode = ((DbException) t).getErrorCode();
                } else if (t instanceof SQLException) {
                    errorCode = ((SQLException) t).getErrorCode();
                } else {
                    errorCode = 0;
                }
                if (errorCode != expectedErrorCode) {
                    AssertionError ae = new AssertionError(
                            "Expected an SQLException or DbException with error code " + expectedErrorCode);
                    ae.initCause(t);
                    throw ae;
                }
                return false;
            }
        }, obj);
    }

    /**
     * Verify the next method call on the object will throw an exception.
     *
     * @param <T> the class of the object
     * @param verifier the result verifier to call
     * @param obj the object to wrap
     * @return a proxy for the object
     */
    @SuppressWarnings("unchecked")
    protected <T> T assertThrows(final ResultVerifier verifier, final T obj) {
        Class<?> c = obj.getClass();
        InvocationHandler ih = new InvocationHandler() {
            // holds the creation stack trace until a method is called
            private Exception called = new Exception("No method called");

            @Override
            protected void finalize() {
                // the proxy was discarded without any method being called
                if (called != null) {
                    called.printStackTrace(System.err);
                }
            }

            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
                try {
                    called = null;
                    Object ret = method.invoke(obj, args);
                    verifier.verify(ret, null, method, args);
                    return ret;
                } catch (InvocationTargetException e) {
                    verifier.verify(null, e.getTargetException(), method, args);
                    // the exception was accepted: return a default value that
                    // matches the declared return type
                    Class<?> retClass = method.getReturnType();
                    if (!retClass.isPrimitive()) {
                        return null;
                    }
                    if (retClass == boolean.class) {
                        return false;
                    } else if (retClass == byte.class) {
                        return (byte) 0;
                    } else if (retClass == char.class) {
                        return (char) 0;
                    } else if (retClass == short.class) {
                        return (short) 0;
                    } else if (retClass == int.class) {
                        return 0;
                    } else if (retClass == long.class) {
                        return 0L;
                    } else if (retClass == float.class) {
                        return 0F;
                    } else if (retClass == double.class) {
                        return 0D;
                    }
                    return null;
                }
            }
        };
        if (!ProxyCodeGenerator.isGenerated(c)) {
            Class<?>[] interfaces = c.getInterfaces();
            if (Modifier.isFinal(c.getModifiers()) || (interfaces.length > 0 && getClass() != c)) {
                // interface class proxies
                if (interfaces.length == 0) {
                    throw new RuntimeException("Can not create a proxy for the class " + c.getSimpleName()
                            + " because it doesn't implement any interfaces and is final");
                }
                return (T) Proxy.newProxyInstance(c.getClassLoader(), interfaces, ih);
            }
        }
        try {
            Class<?> pc = ProxyCodeGenerator.getClassProxy(c);
            Constructor<?> cons = pc.getConstructor(InvocationHandler.class);
            return (T) cons.newInstance(ih);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create a proxy class that extends the given class.
     *
     * @param clazz the class
     */
    protected void createClassProxy(Class<?> clazz) {
        try {
            ProxyCodeGenerator.getClassProxy(clazz);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Print the column names and all remaining rows of a result set to system
     * out, one row per line. Errors are printed and not propagated.
     *
     * @param rs the result set, or null to print nothing
     */
    public static void printResultSet(ResultSet rs) {
        try {
            if (rs != null) {
                ResultSetMetaData md = rs.getMetaData();
                int cols = md.getColumnCount();
                StringBuilder sb = new StringBuilder();
                for (int i = 0; i < cols; i++) {
                    sb.append(md.getColumnName(i + 1)).append(" ");
                }
                sb.append('\n');
                for (int i = 0; i < cols; i++) {
                    sb.append("-------------------");
                }
                sb.append('\n');
                while (rs.next()) {
                    for (int i = 0; i < cols; i++) {
                        sb.append(rs.getString(i + 1)).append(" ");
                    }
                    // terminate each row so rows do not run together
                    sb.append('\n');
                }
                System.out.println(sb.toString());
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Get the name of this test.
     *
     * @return the simple class name of the test
     */
    protected String getTestName() {
        return getClass().getSimpleName();
    }

    /**
     * Drop the given table. A failure of the drop statement itself is
     * ignored, on the assumption that the table did not exist.
     *
     * @param name the table name; it is inserted into the statement as is
     * @throws SQLException if no connection or statement could be created
     */
    public void dropTable(String name) throws SQLException {
        Connection connection = getConnection();
        Statement stm = null;
        try {
            stm = connection.createStatement();
            try {
                stm.executeUpdate("drop table " + name);
            } catch (SQLException sqle) {
                // assume the table didn't exist
            }
        } finally {
            // close the statement before the connection it belongs to
            if (stm != null) {
                JdbcUtils.closeSilently(stm);
            }
            JdbcUtils.closeSilently(connection);
        }
    }

    /**
     * Open a connection to the H2 database {@code jdbc:h2:~/testdb}, passing
     * the system properties as connection properties.
     *
     * @return the connection
     * @throws IllegalArgumentException if the connection can not be opened
     */
    public Connection getH2Connection() {
        String url = "jdbc:h2:~/testdb";
        try {
            return DriverManager.getConnection(url, System.getProperties());
        } catch (SQLException e) {
            // keep the original failure as the cause
            throw new IllegalArgumentException("Could not connect to " + url, e);
        }
    }
}