Java tutorial
/* * * The DbUnit Database Testing Framework * Copyright (C)2002-2005, DbUnit.org * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free Software * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA * */ package org.dbunit.database; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; import org.apache.commons.collections.map.ListOrderedMap; import org.dbunit.database.search.ForeignKeyRelationshipEdge; import org.dbunit.dataset.DataSetException; import org.dbunit.dataset.IDataSet; import org.dbunit.dataset.ITable; import org.dbunit.dataset.ITableIterator; import org.dbunit.dataset.ITableMetaData; import org.dbunit.dataset.filter.AbstractTableFilter; import org.dbunit.util.SQLHelper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Filter a table given a map of the allowed rows based on primary key values.<br> * It uses a depth-first algorithm (although not recursive - it might be refactored * in the future) to define which rows are allowed, as well which rows are necessary * (and hence allowed) because of dependencies with the allowed rows.<br> * <strong>NOTE:</strong> multi-column primary keys are not supported at the moment. * TODO: test cases * @author Felipe Leme (dbunit@felipeal.net) * @author Last changed by: $Author$ * @version $Revision$ $Date$ * @since Sep 9, 2005 */ public class PrimaryKeyFilter extends AbstractTableFilter { private final IDatabaseConnection connection; private final PkTableMap allowedPKsPerTable; private final PkTableMap allowedPKsInput; private final PkTableMap pksToScanPerTable; private final boolean reverseScan; protected final Logger logger = LoggerFactory.getLogger(getClass()); // cache the primary keys private final Map pkColumnPerTable = new HashMap(); private final Map fkEdgesPerTable = new HashMap(); private final Map fkReverseEdgesPerTable = new HashMap(); // name of the tables, in reverse order of dependency private final List tableNames = new ArrayList(); /** * Default constructor, it takes as input a map with desired rows in a final * dataset; the filter will ensure that the rows necessary by these initial rows * are also allowed (and so on...). * @param connection database connection * @param allowedPKs map of allowed rows, based on the primary keys (key is the name * of a table; value is a Set with allowed primary keys for that table) * @param reverseDependency flag indicating if the rows that depend on a row should * also be allowed by the filter */ public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) { this.connection = connection; this.allowedPKsPerTable = new PkTableMap(); this.allowedPKsInput = allowedPKs; this.reverseScan = reverseDependency; // we need a deep copy here this.pksToScanPerTable = new PkTableMap(allowedPKs); } public void nodeAdded(Object node) { this.tableNames.add(node); if (this.logger.isDebugEnabled()) { this.logger.debug("nodeAdded: " + node); } } public void edgeAdded(ForeignKeyRelationshipEdge edge) { if (this.logger.isDebugEnabled()) { this.logger.debug("edgeAdded: " + edge); } // first add it to the "direct edges" String from = (String) edge.getFrom(); Set edges = (Set) this.fkEdgesPerTable.get(from); if (edges == null) { edges = new HashSet(); this.fkEdgesPerTable.put(from, edges); } if (!edges.contains(edge)) { edges.add(edge); } // then add it to the "reverse edges" String to = (String) edge.getTo(); edges = (Set) this.fkReverseEdgesPerTable.get(to); if (edges == null) { edges = new HashSet(); this.fkReverseEdgesPerTable.put(to, edges); } if (!edges.contains(edge)) { edges.add(edge); } // finally, update the PKs cache updatePkCache(to, edge); } /** * @see AbstractTableFilter */ public boolean isValidName(String tableName) throws DataSetException { // boolean isValid = this.allowedIds.containsKey(tableName); // return isValid; return true; } public ITableIterator iterator(IDataSet dataSet, boolean reversed) throws DataSetException { if (this.logger.isDebugEnabled()) { this.logger.debug("Filter.iterator()"); } try { searchPKs(dataSet); } catch (SQLException e) { throw new DataSetException(e); } return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet.iterator()); } private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException { logger.debug("searchPKs(dataSet={}) - start", dataSet); int counter = 0; while (!this.pksToScanPerTable.isEmpty()) { counter++; if (this.logger.isDebugEnabled()) { this.logger.debug("RUN # " + counter); } for (int i = this.tableNames.size() - 1; i >= 0; i--) { String tableName = (String) this.tableNames.get(i); // TODO: support multi-column PKs String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0] .getColumnName(); Set tmpSet = this.pksToScanPerTable.get(tableName); if (tmpSet != null && !tmpSet.isEmpty()) { Set pksToScan = new HashSet(tmpSet); if (this.logger.isDebugEnabled()) { this.logger.debug("before search: " + tableName + "=>" + pksToScan); } scanPKs(tableName, pkColumn, pksToScan); scanReversePKs(tableName, pksToScan); allowPKs(tableName, pksToScan); removePKsToScan(tableName, pksToScan); } // if } // for removeScannedTables(); } // while if (this.logger.isDebugEnabled()) { this.logger.debug("Finished searchIds()"); } } private void removeScannedTables() { logger.debug("removeScannedTables() - start"); this.pksToScanPerTable.retainOnly(this.tableNames); } private void allowPKs(String table, Set newAllowedPKs) { logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs); // then, add the new IDs, but checking if it should be allowed to add them Set forcedAllowedPKs = this.allowedPKsInput.get(table); if (forcedAllowedPKs == null || forcedAllowedPKs.isEmpty()) { allowedPKsPerTable.addAll(table, newAllowedPKs); } else { for (Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext();) { Object id = iterator.next(); if (forcedAllowedPKs.contains(id)) { allowedPKsPerTable.add(table, id); } else { if (this.logger.isDebugEnabled()) { this.logger.debug("Discarding id " + id + " of table " + table + " as it was not included in the input!"); } } } } } private void scanPKs(String table, String pkColumn, Set allowedIds) throws SQLException { if (logger.isDebugEnabled()) { logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start", new Object[] { table, pkColumn, allowedIds }); } Set fkEdges = (Set) this.fkEdgesPerTable.get(table); if (fkEdges == null || fkEdges.isEmpty()) { return; } // we need a temporary list as there is no warranty about the set order... List fkTables = new ArrayList(fkEdges.size()); StringBuffer colsBuffer = new StringBuffer(); for (Iterator iterator = fkEdges.iterator(); iterator.hasNext();) { ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next(); fkTables.add(edge.getTo()); colsBuffer.append(edge.getFKColumn()); if (iterator.hasNext()) { colsBuffer.append(", "); } } // NOTE: make sure the query below is compatible standard SQL String sql = "SELECT " + colsBuffer + " FROM " + table + " WHERE " + pkColumn + " = ? "; if (this.logger.isDebugEnabled()) { this.logger.debug("SQL: " + sql); } scanPKs(table, sql, allowedIds, fkTables); } private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException { PreparedStatement pstmt = null; ResultSet rs = null; try { pstmt = this.connection.getConnection().prepareStatement(sql); for (Iterator iterator = allowedIds.iterator(); iterator.hasNext();) { Object pk = iterator.next(); // id being scanned if (this.logger.isDebugEnabled()) { this.logger.debug("Executing sql for ? = " + pk); } pstmt.setObject(1, pk); rs = pstmt.executeQuery(); while (rs.next()) { for (int i = 0; i < fkTables.size(); i++) { String newTable = (String) fkTables.get(i); Object fk = rs.getObject(i + 1); if (fk != null) { if (this.logger.isDebugEnabled()) { this.logger.debug("New ID: " + newTable + "->" + fk); } addPKToScan(newTable, fk); } else { this.logger.warn("Found null FK for relationship " + table + "=>" + newTable); } } } } } catch (SQLException e) { logger.error("scanPKs()", e); } finally { // new in the finally block. has been in the catch only before SQLHelper.close(rs, pstmt); } } private void scanReversePKs(String table, Set pksToScan) throws SQLException { logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan); if (!this.reverseScan) { return; } Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get(table); if (fkReverseEdges == null || fkReverseEdges.isEmpty()) { return; } Iterator iterator = fkReverseEdges.iterator(); while (iterator.hasNext()) { ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next(); addReverseEdge(edge, pksToScan); } } private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException { logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan); String fkTable = (String) edge.getFrom(); String fkColumn = edge.getFKColumn(); String pkColumn = getPKColumn(fkTable); // NOTE: make sure the query below is compatible standard SQL String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? "; PreparedStatement pstmt = null; ResultSet rs = null; try { if (this.logger.isDebugEnabled()) { this.logger.debug("Preparing SQL query '" + sql + "'"); } pstmt = this.connection.getConnection().prepareStatement(sql); for (Iterator iterator = idsToScan.iterator(); iterator.hasNext();) { Object pk = iterator.next(); if (this.logger.isDebugEnabled()) { this.logger.debug("executing query '" + sql + "' for ? = " + pk); } pstmt.setObject(1, pk); rs = pstmt.executeQuery(); while (rs.next()) { Object fk = rs.getObject(1); addPKToScan(fkTable, fk); } } } finally { SQLHelper.close(rs, pstmt); } } private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) { logger.debug("updatePkCache(to={}, edge={}) - start", table, edge); Object pkTo = this.pkColumnPerTable.get(table); if (pkTo == null) { String pkColumn = edge.getPKColumn(); this.pkColumnPerTable.put(table, pkColumn); } } // TODO: support PKs with multiple values private String getPKColumn(String table) throws SQLException { logger.debug("getPKColumn(table={}) - start", table); // Try to get the cached column String pkColumn = (String) this.pkColumnPerTable.get(table); if (pkColumn == null) { // If the column has not been cached until now retrieve it from the database connection pkColumn = SQLHelper.getPrimaryKeyColumn(this.connection.getConnection(), table); this.pkColumnPerTable.put(table, pkColumn); } return pkColumn; } private void removePKsToScan(String table, Set ids) { logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids); Set pksToScan = this.pksToScanPerTable.get(table); if (pksToScan != null) { if (pksToScan == ids) { throw new RuntimeException("INTERNAL ERROR on removeIdsToScan() for table " + table); } else { pksToScan.removeAll(ids); } } } private void addPKToScan(String table, Object pk) { logger.debug("addPKToScan(table={}, pk={}) - start", table, pk); // first, check if it wasn't added yet if (this.allowedPKsPerTable.contains(table, pk)) { if (this.logger.isDebugEnabled()) { this.logger.debug("Discarding already scanned id=" + pk + " for table " + table); } return; } this.pksToScanPerTable.add(table, pk); } public String toString() { StringBuffer sb = new StringBuffer(); sb.append("tableNames=").append(tableNames); sb.append(", allowedPKsInput=").append(allowedPKsInput); sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable); sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable); sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable); sb.append(", pkColumnPerTable=").append(pkColumnPerTable); sb.append(", pksToScanPerTable=").append(pksToScanPerTable); sb.append(", reverseScan=").append(reverseScan); sb.append(", connection=").append(connection); return sb.toString(); } private class FilterIterator implements ITableIterator { private final ITableIterator _iterator; public FilterIterator(ITableIterator iterator) { _iterator = iterator; } //////////////////////////////////////////////////////////////////////////// // ITableIterator interface public boolean next() throws DataSetException { if (logger.isDebugEnabled()) { logger.debug("Iterator.next()"); } while (_iterator.next()) { if (accept(_iterator.getTableMetaData().getTableName())) { return true; } } return false; } public ITableMetaData getTableMetaData() throws DataSetException { if (logger.isDebugEnabled()) { logger.debug("Iterator.getTableMetaData()"); } return _iterator.getTableMetaData(); } public ITable getTable() throws DataSetException { if (logger.isDebugEnabled()) { logger.debug("Iterator.getTable()"); } ITable table = _iterator.getTable(); String tableName = table.getTableMetaData().getTableName(); Set allowedPKs = allowedPKsPerTable.get(tableName); if (allowedPKs != null) { return new PrimaryKeyFilteredTableWrapper(table, allowedPKs); } return table; } } /** * Map that associates a table with a set of primary key objects. * * @author gommma (gommma AT users.sourceforge.net) * @author Last changed by: $Author$ * @version $Revision$ $Date$ * @since 2.3.0 */ public static class PkTableMap { private final ListOrderedMap pksPerTable; private final Logger logger = LoggerFactory.getLogger(PkTableMap.class); public PkTableMap() { this.pksPerTable = new ListOrderedMap(); } /** * Copy constructor * @param allowedPKs */ public PkTableMap(PkTableMap allowedPKs) { this.pksPerTable = new ListOrderedMap(); Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry entry = (Map.Entry) iterator.next(); String table = (String) entry.getKey(); SortedSet pkObjectSet = (SortedSet) entry.getValue(); SortedSet newSet = new TreeSet(pkObjectSet); this.pksPerTable.put(table, newSet); } } public int size() { return pksPerTable.size(); } public boolean isEmpty() { return pksPerTable.isEmpty(); } public boolean contains(String table, Object pkObject) { Set pksPerTable = this.get(table); return (pksPerTable != null && pksPerTable.contains(pkObject)); } public void remove(String tableName) { this.pksPerTable.remove(tableName); } public void put(String table, SortedSet pkObjects) { this.pksPerTable.put(table, pkObjects); } public void add(String tableName, Object pkObject) { Set pksPerTable = getCreateIfNeeded(tableName); pksPerTable.add(pkObject); } public void addAll(String tableName, Set pkObjectsToAdd) { Set pksPerTable = this.getCreateIfNeeded(tableName); pksPerTable.addAll(pkObjectsToAdd); } public SortedSet get(String tableName) { return (SortedSet) this.pksPerTable.get(tableName); } private SortedSet getCreateIfNeeded(String tableName) { SortedSet pksPerTable = this.get(tableName); // Lazily create the set if it did not exist yet if (pksPerTable == null) { pksPerTable = new TreeSet(); this.pksPerTable.put(tableName, pksPerTable); } return pksPerTable; } public String[] getTableNames() { return (String[]) this.pksPerTable.keySet().toArray(new String[0]); } public void retainOnly(List tableNames) { List tablesToRemove = new ArrayList(); for (Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext();) { Map.Entry entry = (Map.Entry) iterator.next(); String table = (String) entry.getKey(); SortedSet pksToScan = (SortedSet) entry.getValue(); boolean removeIt = pksToScan.isEmpty(); if (!tableNames.contains(table)) { if (this.logger.isWarnEnabled()) { this.logger.warn("Discarding ids " + pksToScan + " of table " + table + "as this table has not been passed as input"); } removeIt = true; } if (removeIt) { tablesToRemove.add(table); } } for (Iterator iterator = tablesToRemove.iterator(); iterator.hasNext();) { this.remove((String) iterator.next()); } } public String toString() { StringBuffer sb = new StringBuffer(); sb.append("pKsPerTable=").append(pksPerTable); return sb.toString(); } } }