com.yahoo.ycsb.db.PostgreSQLJsonbClient.java Source code

Java tutorial

Introduction

Here is the source code for com.yahoo.ycsb.db.PostgreSQLJsonbClient.java

Source

/**
 * Copyright (c) 2010 - 2016 Yahoo! Inc. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You
 * may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License. See accompanying
 * LICENSE file.
 */

package com.yahoo.ycsb.db;

import com.yahoo.ycsb.*;
import java.sql.*;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.commons.lang3.StringEscapeUtils;

/**
 * A class that wraps PostgreSQL database to benchmark its JSONB features
 * with YCSB. This class extends {@link DB} and implements the database
 * interface used by YCSB client.
 * This class also uses code from jdbc-binding.
 *
 * <br>
 * Each client will have its own instance of this class. This client is not
 * thread safe.
 *
 * <br>
 * This interface expects a schema <key> <fields>.
 * <key> is of type VARCHAR and <fields> is of type JSONB.
 * All accesses are through the primary key.
 * Therefore, only one index on the primary key is needed.
 */
public class PostgreSQLJsonbClient extends DB {

    /**
     * Property key naming the JDBC driver class. Optional: when the property
     * is set, {@link #init()} loads the class with {@code Class.forName}.
     */
    public static final String DRIVER_CLASS = "db.driver";

    /**
     * Property key for the database URL(s). {@link #init()} splits the value
     * on commas and opens one connection (one shard) per element.
     */
    public static final String CONNECTION_URL = "db.url";

    /** Property key for the user name used to open every shard connection. */
    public static final String CONNECTION_USER = "db.user";

    /** Property key for the password used to open every shard connection. */
    public static final String CONNECTION_PASSWD = "db.passwd";

    /**
     * Property key for the insert batch size. A value greater than zero makes
     * {@link #insert} queue rows with {@code addBatch()} instead of executing
     * them one by one.
     */
    public static final String DB_BATCH_SIZE = "db.batchsize";

    /**
     * Property key for the JDBC fetch size. When greater than zero it is
     * applied to scan statements via {@code setFetchSize}.
     */
    public static final String JDBC_FETCH_SIZE = "jdbc.fetchsize";

    /**
     * Property key for the connection auto-commit flag. Treated as
     * {@code true} when the property is absent.
     */
    public static final String JDBC_AUTO_COMMIT = "jdbc.autocommit";

    /**
     * The name of the property for the number of fields in a record.
     * Not read anywhere in this class.
     */
    public static final String FIELD_COUNT_PROPERTY = "fieldcount";

    /** Default number of fields in a record. Not read anywhere in this class. */
    public static final String FIELD_COUNT_PROPERTY_DEFAULT = "10";

    /** Representing a NULL value. Not referenced anywhere in this class. */
    public static final String NULL_VALUE = "NULL";

    /** Name of the primary-key column used in every generated SQL statement. */
    public static final String PRIMARY_KEY = "YCSB_KEY";

    /** The field name prefix in the table. Not referenced anywhere in this class. */
    public static final String COLUMN_PREFIX = "FIELD";

    // One open connection per shard URL, in the order the URLs were listed.
    private ArrayList<Connection> conns;
    // Set once init() has completed successfully; guards against double initialisation.
    private boolean initialized = false;
    // Snapshot of the client properties taken in init().
    private Properties props;
    // Parsed value of JDBC_FETCH_SIZE, or -1 when the property is not set.
    private int jdbcFetchSize;
    // Parsed value of DB_BATCH_SIZE, or -1 when the property is not set.
    private int batchSize;
    // Fallback used for the URL, user and password properties when they are absent.
    private static final String DEFAULT_PROP = "";
    // Prepared statements keyed by operation type, table, field layout and shard.
    private ConcurrentMap<StatementType, PreparedStatement> cachedStatements;
    // Number of rows queued through insert() so far; one counter shared by all shards.
    private long numRowsInBatch = 0;

    /**
     * Immutable holder pairing a field-name string with the list of field
     * values that belongs to it. Both are stored and handed back exactly as
     * given; the list is not copied.
     */
    private static class OrderedFieldInfo {
        private final String fieldKeys;
        private final List<String> fieldValues;

        /**
         * @param fieldKeys   field names in statement order
         * @param fieldValues values in the same order as {@code fieldKeys}
         */
        OrderedFieldInfo(String fieldKeys, List<String> fieldValues) {
            this.fieldValues = fieldValues;
            this.fieldKeys = fieldKeys;
        }

        /** @return the field-name string passed to the constructor */
        String getFieldKeys() {
            return this.fieldKeys;
        }

        /** @return the very list instance passed to the constructor */
        List<String> getFieldValues() {
            return this.fieldValues;
        }
    }

    /**
     * Cache key for a prepared statement. Two keys are equal when they agree
     * on operation type, table name, number of fields, field string and shard
     * index, so each distinct combination gets its own cached statement.
     */
    private static class StatementType {

        /** The kind of SQL operation a cached statement performs. */
        enum Type {
            INSERT(1), DELETE(2), READ(3), UPDATE(4), SCAN(5);

            private final int internalType;

            private Type(int type) {
                internalType = type;
            }

            /**
             * @return a hash value derived only from the explicit numeric code
             *         of the constant
             */
            int getHashCode() {
                final int prime = 31;
                int result = 1;
                result = prime * result + internalType;
                return result;
            }
        }

        private final Type type;
        private final int shardIndex;
        private final int numFields;
        private final String tableName;
        private final String fieldString;

        /**
         * @param type        operation performed by the statement
         * @param tableName   table the statement targets
         * @param numFields   number of non-key columns in the statement
         * @param fieldString comma-separated non-key column names (may be empty)
         * @param shardIndex  index of the shard connection the statement belongs to
         */
        StatementType(Type type, String tableName, int numFields, String fieldString, int shardIndex) {
            this.type = type;
            this.tableName = tableName;
            this.numFields = numFields;
            this.fieldString = fieldString;
            this.shardIndex = shardIndex;
        }

        /** Null-safe equality, written out by hand to avoid newer JDK helpers. */
        private static boolean sameOrBothNull(Object left, Object right) {
            return (left == null) ? (right == null) : left.equals(right);
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + numFields + 100 * shardIndex;
            result = prime * result + ((tableName == null) ? 0 : tableName.hashCode());
            result = prime * result + ((type == null) ? 0 : type.getHashCode());
            // fieldString takes part in equals(), so let it contribute here too.
            result = prime * result + ((fieldString == null) ? 0 : fieldString.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            StatementType other = (StatementType) obj;
            return numFields == other.numFields
                    && shardIndex == other.shardIndex
                    && type == other.type
                    && sameOrBothNull(tableName, other.tableName)
                    // Previously an unguarded fieldString.equals(...) that threw on null.
                    && sameOrBothNull(fieldString, other.fieldString);
        }
    }

    /**
     * For the given key, returns what shard contains data for this key.
     *
     * <p>The index is the absolute value of the key's hash code reduced modulo
     * the number of shard connections.
     *
     * @param key Data key to do operation on
     * @return Shard index in the range {@code [0, conns.size())}
     */
    private int getShardIndexByKey(String key) {
        // Take the remainder BEFORE the absolute value. Math.abs(Integer.MIN_VALUE)
        // is still negative, so abs-then-mod could yield a negative index for a key
        // whose hash is MIN_VALUE. The remainder always lies in (-size, size), where
        // Math.abs is safe, and for every other hash the result is unchanged.
        return Math.abs(key.hashCode() % conns.size());
    }

    /**
     * Looks up the connection of the shard responsible for a key.
     *
     * @param key Data key whose shard connection is wanted
     * @return the connection stored at the key's shard index
     */
    private Connection getShardConnectionByKey(String key) {
        final int shardIndex = getShardIndexByKey(key);
        return conns.get(shardIndex);
    }

    /**
     * Closes every shard connection.
     *
     * <p>All connections are attempted even if some fail to close, so one bad
     * connection cannot keep the others open. Failures are chained with
     * {@link SQLException#setNextException} and the first one is thrown once
     * the loop has finished.
     *
     * @throws SQLException the first close failure, with later ones chained to it
     */
    private void cleanupAllConnections() throws SQLException {
        SQLException firstFailure = null;
        for (Connection conn : conns) {
            try {
                conn.close();
            } catch (SQLException e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.setNextException(e);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /**
     * Reads an integer-valued property.
     *
     * @param props properties to read from
     * @param key   name of the property
     * @return the parsed value, or -1 when the property is not set
     * @throws DBException wrapping the {@link NumberFormatException} when the
     *                     value is present but not a valid integer
     */
    private static int getIntProperty(Properties props, String key) throws DBException {
        final String raw = props.getProperty(key);
        if (raw == null) {
            return -1;
        }
        try {
            return Integer.parseInt(raw);
        } catch (NumberFormatException nfe) {
            System.err.println("Invalid " + key + " specified: " + raw);
            throw new DBException(nfe);
        }
    }

    /**
     * Opens one JDBC connection per comma-separated URL in {@code db.url} and
     * prepares the statement cache.
     *
     * <p>Reads {@code db.url}, {@code db.user}, {@code db.passwd},
     * {@code db.driver}, {@code jdbc.fetchsize}, {@code db.batchsize} and
     * {@code jdbc.autocommit} (default {@code true}). Calling it again on an
     * already initialised client only prints a warning.
     *
     * <p>If any step fails, every connection opened so far is closed before
     * the exception is thrown, so a partially initialised client holds no
     * database resources.
     *
     * @throws DBException if the driver class cannot be loaded, a property is
     *                     not a valid integer, or a connection cannot be
     *                     opened or configured
     */
    @Override
    public void init() throws DBException {
        if (initialized) {
            System.err.println("Client connection already initialized.");
            return;
        }
        props = getProperties();
        String urls = props.getProperty(CONNECTION_URL, DEFAULT_PROP);
        String user = props.getProperty(CONNECTION_USER, DEFAULT_PROP);
        String passwd = props.getProperty(CONNECTION_PASSWD, DEFAULT_PROP);
        String driver = props.getProperty(DRIVER_CLASS);

        this.jdbcFetchSize = getIntProperty(props, JDBC_FETCH_SIZE);
        this.batchSize = getIntProperty(props, DB_BATCH_SIZE);

        String autoCommitStr = props.getProperty(JDBC_AUTO_COMMIT, Boolean.TRUE.toString());
        boolean autoCommit = Boolean.parseBoolean(autoCommitStr);

        boolean connected = false;
        try {
            if (driver != null) {
                Class.forName(driver);
            }
            conns = new ArrayList<Connection>(3);
            for (String rawUrl : urls.split(",")) {
                // Tolerate "url1, url2" style lists.
                String url = rawUrl.trim();
                System.out.println("Adding shard node URL: " + url);
                Connection conn = DriverManager.getConnection(url, user, passwd);
                // Track the connection before configuring it so the failure
                // path below can close it if setAutoCommit throws.
                conns.add(conn);

                // Since there is no explicit commit method in the DB interface, all
                // operations should auto commit, except when explicitly told not to
                // (this is necessary in cases such as for PostgreSQL when running a
                // scan workload with fetchSize)
                conn.setAutoCommit(autoCommit);
            }

            System.out.println(
                    "Using shards: " + conns.size() + ", batchSize:" + batchSize + ", fetchSize: " + jdbcFetchSize);

            cachedStatements = new ConcurrentHashMap<StatementType, PreparedStatement>();
            connected = true;
        } catch (ClassNotFoundException e) {
            System.err.println("Error in initializing the JDBC driver: " + e);
            throw new DBException(e);
        } catch (SQLException e) {
            System.err.println("Error in database operation: " + e);
            throw new DBException(e);
        } catch (NumberFormatException e) {
            System.err.println("Invalid value for fieldcount property. " + e);
            throw new DBException(e);
        } finally {
            if (!connected) {
                closeConnectionsAfterFailedInit();
            }
        }
        initialized = true;
    }

    /** Best-effort close of whatever connections a failed init() had opened. */
    private void closeConnectionsAfterFailedInit() {
        if (conns == null) {
            return;
        }
        for (Connection conn : conns) {
            try {
                conn.close();
            } catch (SQLException ignored) {
                // The original initialisation failure is already being reported;
                // a secondary close failure must not mask it.
            }
        }
        conns.clear();
    }

    /**
     * Flushes any batched inserts that are still pending and closes all shard
     * connections.
     *
     * <p>The row counter used by {@link #insert} is shared by every shard
     * while batches are held per statement, so a statement can still hold
     * queued rows when the counter happens to be a multiple of the batch
     * size. Every open cached statement is therefore flushed unconditionally.
     * Connections are closed even when the flush fails; the first failure is
     * the one reported.
     *
     * <p>NOTE(review): no {@code commit()} is issued here. With
     * {@code jdbc.autocommit=false} it looks like flushed rows are left to
     * whatever the driver does on close; confirm whether an explicit commit
     * is wanted.
     *
     * @throws DBException if flushing a batch or closing a connection fails
     */
    @Override
    public void cleanup() throws DBException {
        DBException failure = null;

        // cachedStatements is null when init() never completed.
        if (batchSize > 0 && cachedStatements != null) {
            try {
                // commit un-finished batches
                for (PreparedStatement st : cachedStatements.values()) {
                    // Check the statement first: asking a closed statement for
                    // its connection is not guaranteed to work.
                    if (!st.isClosed() && !st.getConnection().isClosed()) {
                        st.executeBatch();
                    }
                }
            } catch (SQLException e) {
                System.err.println("Error in cleanup execution. " + e);
                failure = new DBException(e);
            }
        }

        if (conns != null) {
            try {
                cleanupAllConnections();
            } catch (SQLException e) {
                System.err.println("Error in closing the connection. " + e);
                if (failure == null) {
                    failure = new DBException(e);
                }
            }
        }

        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Prepares the INSERT statement described by {@code insertType} on the
     * shard that owns {@code key} and stores it in the statement cache.
     *
     * <p>The generated SQL has one placeholder for the primary key followed by
     * {@code insertType.numFields} placeholders for the non-key columns. Each
     * non-key placeholder is wrapped in {@code CAST(? AS jsonb)}: callers bind
     * the JSON text with {@code setString}, and without the cast the statement
     * only works when the connection sends strings as untyped parameters.
     *
     * @param insertType cache key carrying table name, column list and column count
     * @param key        record key, used only to pick the shard connection
     * @return the statement now held in the cache for {@code insertType}
     * @throws SQLException if the statement cannot be prepared
     */
    private PreparedStatement createAndCacheInsertStatement(StatementType insertType, String key)
            throws SQLException {
        StringBuilder insert = new StringBuilder("INSERT INTO ");
        insert.append(insertType.tableName);
        insert.append(" (").append(PRIMARY_KEY).append(",").append(insertType.fieldString).append(")");
        insert.append(" VALUES(?");
        for (int i = 0; i < insertType.numFields; i++) {
            insert.append(",CAST(? AS jsonb)");
        }
        insert.append(")");
        PreparedStatement insertStatement = getShardConnectionByKey(key).prepareStatement(insert.toString());
        PreparedStatement stmt = cachedStatements.putIfAbsent(insertType, insertStatement);
        if (stmt == null) {
            return insertStatement;
        }
        return stmt;
    }

    /**
     * Prepares {@code SELECT * FROM <table> WHERE YCSB_KEY = ?} on the shard
     * that owns {@code key} and registers it in the statement cache.
     *
     * @param readType cache key carrying the table name
     * @param key      record key, used only to pick the shard connection
     * @return the statement held in the cache for {@code readType}
     * @throws SQLException if the statement cannot be prepared
     */
    private PreparedStatement createAndCacheReadStatement(StatementType readType, String key) throws SQLException {
        final String sql = "SELECT * FROM " + readType.tableName + " WHERE " + PRIMARY_KEY + " = " + "?";
        final PreparedStatement fresh = getShardConnectionByKey(key).prepareStatement(sql);
        final PreparedStatement alreadyCached = cachedStatements.putIfAbsent(readType, fresh);
        return (alreadyCached != null) ? alreadyCached : fresh;
    }

    /**
     * Prepares {@code DELETE FROM <table> WHERE YCSB_KEY = ?} on the shard
     * that owns {@code key} and registers it in the statement cache.
     *
     * @param deleteType cache key carrying the table name
     * @param key        record key, used only to pick the shard connection
     * @return the statement held in the cache for {@code deleteType}
     * @throws SQLException if the statement cannot be prepared
     */
    private PreparedStatement createAndCacheDeleteStatement(StatementType deleteType, String key)
            throws SQLException {
        final String sql = "DELETE FROM " + deleteType.tableName + " WHERE " + PRIMARY_KEY + " = ?";
        final PreparedStatement fresh = getShardConnectionByKey(key).prepareStatement(sql);
        final PreparedStatement alreadyCached = cachedStatements.putIfAbsent(deleteType, fresh);
        return (alreadyCached != null) ? alreadyCached : fresh;
    }

    /**
     * Prepares the UPDATE statement described by {@code updateType} on the
     * shard that owns {@code key} and stores it in the statement cache.
     *
     * <p>For every column named in {@code updateType.fieldString} the SQL
     * contains {@code col=jsonb_set(col, CAST(? AS text[]), CAST(? AS jsonb))},
     * i.e. two placeholders per column (path, new value), followed by one
     * placeholder for the primary key. Assignments are comma-separated so a
     * field string naming several columns yields valid SQL. The explicit
     * casts let callers bind both arguments with {@code setString}.
     *
     * @param updateType cache key carrying table name and column list
     * @param key        record key, used only to pick the shard connection
     * @return the statement now held in the cache for {@code updateType}
     * @throws SQLException if the statement cannot be prepared
     */
    private PreparedStatement createAndCacheUpdateStatement(StatementType updateType, String key)
            throws SQLException {
        String[] fieldKeys = updateType.fieldString.split(",");
        StringBuilder update = new StringBuilder("UPDATE ");
        update.append(updateType.tableName);
        update.append(" SET ");
        for (int i = 0; i < fieldKeys.length; i++) {
            if (i > 0) {
                update.append(", ");
            }
            update.append(fieldKeys[i]);
            update.append("=jsonb_set(");
            update.append(fieldKeys[i]);
            update.append(", CAST(? AS text[]), CAST(? AS jsonb))");
        }
        update.append(" WHERE ");
        update.append(PRIMARY_KEY);
        update.append(" = ?");
        PreparedStatement updateStatement = getShardConnectionByKey(key).prepareStatement(update.toString());
        PreparedStatement stmt = cachedStatements.putIfAbsent(updateType, updateStatement);
        if (stmt == null) {
            return updateStatement;
        }
        return stmt;
    }

    /**
     * Prepares the range query
     * {@code SELECT * FROM <table> WHERE YCSB_KEY >= ? ORDER BY YCSB_KEY LIMIT ?}
     * on the shard that owns {@code key} and registers it in the statement
     * cache. A positive configured fetch size is applied to the new statement
     * before it is cached.
     *
     * @param scanType cache key carrying the table name
     * @param key      scan start key, used only to pick the shard connection
     * @return the statement held in the cache for {@code scanType}
     * @throws SQLException if the statement cannot be prepared or configured
     */
    private PreparedStatement createAndCacheScanStatement(StatementType scanType, String key) throws SQLException {
        final String sql = "SELECT * FROM " + scanType.tableName
                + " WHERE " + PRIMARY_KEY + " >= ?"
                + " ORDER BY " + PRIMARY_KEY
                + " LIMIT ?";
        final PreparedStatement fresh = getShardConnectionByKey(key).prepareStatement(sql);
        if (this.jdbcFetchSize > 0) {
            fresh.setFetchSize(this.jdbcFetchSize);
        }
        final PreparedStatement alreadyCached = cachedStatements.putIfAbsent(scanType, fresh);
        return (alreadyCached != null) ? alreadyCached : fresh;
    }

    /**
     * Reads the row with the given key from the shard that owns it.
     *
     * <p>When both {@code result} and {@code fields} are non-null, each
     * requested field is fetched from the result set by column label and put
     * into {@code result}. The result set is always closed, including when
     * fetching a column throws.
     *
     * <p>NOTE(review): fields are looked up as result-set columns, while the
     * class documentation describes a two-column {@code <key> <fields>}
     * schema; confirm the requested names really exist as columns.
     *
     * @param tableName table to read from
     * @param key       primary key of the row
     * @param fields    column labels to copy into {@code result}, or null for none
     * @param result    receives the requested values; may be null
     * @return {@code Status.OK} if a row was found, {@code Status.NOT_FOUND} if
     *         not, {@code Status.ERROR} on any SQL error
     */
    @Override
    public Status read(String tableName, String key, Set<String> fields, HashMap<String, ByteIterator> result) {
        try {
            StatementType type = new StatementType(StatementType.Type.READ, tableName, 1, "",
                    getShardIndexByKey(key));
            PreparedStatement readStatement = cachedStatements.get(type);
            if (readStatement == null) {
                readStatement = createAndCacheReadStatement(type, key);
            }
            readStatement.setString(1, key);
            ResultSet resultSet = readStatement.executeQuery();
            try {
                if (!resultSet.next()) {
                    return Status.NOT_FOUND;
                }
                if (result != null && fields != null) {
                    for (String field : fields) {
                        String value = resultSet.getString(field);
                        result.put(field, new StringByteIterator(value));
                    }
                }
                return Status.OK;
            } finally {
                // Runs on every path, so a failing getString cannot leak the cursor.
                resultSet.close();
            }
        } catch (SQLException e) {
            System.err.println("Error in processing read of table " + tableName + ": " + e);
            return Status.ERROR;
        }
    }

    /**
     * Reads up to {@code recordcount} rows with keys greater than or equal to
     * {@code startKey}, in key order, from the shard that owns
     * {@code startKey}. Only that one shard is queried.
     *
     * <p>When both {@code result} and {@code fields} are non-null, one map per
     * row is appended to {@code result}, filled with the requested fields
     * fetched by column label (the same way {@link #read} does). The result
     * set is always closed.
     *
     * @param tableName   table to scan
     * @param startKey    first key of the range
     * @param recordcount maximum number of rows to return
     * @param fields      column labels to copy per row, or null for none
     * @param result      receives one map per row; may be null
     * @return {@code Status.OK} on success, {@code Status.ERROR} on any SQL error
     */
    @Override
    public Status scan(String tableName, String startKey, int recordcount, Set<String> fields,
            Vector<HashMap<String, ByteIterator>> result) {
        try {
            StatementType type = new StatementType(StatementType.Type.SCAN, tableName, 1, "",
                    getShardIndexByKey(startKey));
            PreparedStatement scanStatement = cachedStatements.get(type);
            if (scanStatement == null) {
                scanStatement = createAndCacheScanStatement(type, startKey);
            }
            scanStatement.setString(1, startKey);
            scanStatement.setInt(2, recordcount);

            // The row loop treats a null result as "do not collect", so the
            // capacity hint must honour that too instead of dereferencing it.
            final boolean collect = result != null && fields != null;
            if (result != null) {
                result.ensureCapacity(recordcount);
            }

            ResultSet resultSet = scanStatement.executeQuery();
            try {
                for (int i = 0; i < recordcount && resultSet.next(); i++) {
                    if (collect) {
                        HashMap<String, ByteIterator> values = new HashMap<String, ByteIterator>();
                        for (String field : fields) {
                            String value = resultSet.getString(field);
                            values.put(field, new StringByteIterator(value));
                        }
                        result.add(values);
                    }
                }
                return Status.OK;
            } finally {
                resultSet.close();
            }
        } catch (SQLException e) {
            System.err.println("Error in processing scan of table: " + tableName + e);
            return Status.ERROR;
        }
    }

    /**
     * Updates the given fields inside the JSONB document of the row with the
     * given key.
     *
     * <p>The cached statement takes exactly one (path, value) pair, so it is
     * executed once per entry of {@code values}. Binding every entry to the
     * same two parameters and executing once would write only the last entry.
     * Each path addresses a top-level key of the JSON object, matching the
     * flat object that {@link #insert} stores.
     *
     * <p>The executions are separate statements; they are not wrapped in a
     * transaction by this method.
     *
     * @param tableName table to update
     * @param key       primary key of the row
     * @param values    field name to new value
     * @return {@code Status.OK} if every execution touched exactly one row (or
     *         {@code values} is empty), {@code Status.UNEXPECTED_STATE} as soon
     *         as one does not, {@code Status.ERROR} on any SQL error
     */
    @Override
    public Status update(String tableName, String key, HashMap<String, ByteIterator> values) {
        try {
            StatementType type = new StatementType(StatementType.Type.UPDATE, tableName, 1, "fields",
                    getShardIndexByKey(key));
            PreparedStatement updateStatement = cachedStatements.get(type);
            if (updateStatement == null) {
                updateStatement = createAndCacheUpdateStatement(type, key);
            }
            // The values are converted exactly once. The former unused
            // getFieldInfo(values) call also invoked toString() on every value
            // beforehand (presumably able to drain the iterators; verify).
            for (Map.Entry<String, String> entry : StringByteIterator.getStringMap(values).entrySet()) {
                updateStatement.setString(1, singleKeyJsonbPath(entry.getKey()));
                updateStatement.setString(2, "\"" + StringEscapeUtils.escapeJson(entry.getValue()) + "\"");
                updateStatement.setString(3, key);
                if (updateStatement.executeUpdate() != 1) {
                    return Status.UNEXPECTED_STATE;
                }
            }
            return Status.OK;
        } catch (SQLException e) {
            System.err.println("Error in processing update to table: " + tableName + e);
            return Status.ERROR;
        }
    }

    /**
     * Builds the text[] literal for a one-element jsonb_set path. The element
     * is double-quoted, with backslashes and double quotes escaped, so names
     * containing commas, braces or spaces stay a single element.
     */
    private static String singleKeyJsonbPath(String fieldName) {
        String escaped = fieldName.replace("\\", "\\\\").replace("\"", "\\\"");
        return "{\"" + escaped + "\"}";
    }

    /**
     * Inserts a row whose JSONB column holds {@code values} serialised as a
     * flat JSON object.
     *
     * <p>Without batching ({@code batchSize <= 0}) the row is written
     * immediately. With batching the row is queued on the shard's statement,
     * and the batch of <em>that</em> statement is executed whenever the
     * client-wide row counter reaches a multiple of the batch size. The
     * counter is shared by all shards, so other shards' statements may still
     * hold queued rows afterwards.
     *
     * @param tableName table to insert into
     * @param key       primary key of the new row
     * @param values    field name to value
     * @return {@code Status.OK} when the row (or the executed batch) was
     *         written, {@code Status.BATCHED_OK} when the row was only queued,
     *         {@code Status.UNEXPECTED_STATE} when a direct insert did not
     *         report exactly one row, {@code Status.ERROR} on a failed batch
     *         entry or any SQL error
     */
    @Override
    public Status insert(String tableName, String key, HashMap<String, ByteIterator> values) {
        try {
            StatementType type = new StatementType(StatementType.Type.INSERT, tableName, 1, "fields",
                    getShardIndexByKey(key));
            PreparedStatement insertStatement = cachedStatements.get(type);
            if (insertStatement == null) {
                insertStatement = createAndCacheInsertStatement(type, key);
            }
            insertStatement.setString(1, key);
            insertStatement.setString(2, getJsonString(values));

            if (batchSize > 0) {
                insertStatement.addBatch();
                if (++numRowsInBatch % batchSize != 0) {
                    return Status.BATCHED_OK;
                }
                for (int r : insertStatement.executeBatch()) {
                    // SUCCESS_NO_INFO means the command succeeded but the driver
                    // cannot report a row count; it is not a failure.
                    if (r != 1 && r != Statement.SUCCESS_NO_INFO) {
                        return Status.ERROR;
                    }
                }
                return Status.OK;
            }

            if (insertStatement.executeUpdate() == 1) {
                return Status.OK;
            }
            return Status.UNEXPECTED_STATE;
        } catch (SQLException e) {
            System.err.println("Error in processing insert to table: " + tableName + e);
            return Status.ERROR;
        }
    }

    /**
     * Deletes the row with the given key from the shard that owns it.
     *
     * @param tableName table to delete from
     * @param key       primary key of the row
     * @return {@code Status.OK} if exactly one row was removed,
     *         {@code Status.UNEXPECTED_STATE} for any other row count,
     *         {@code Status.ERROR} on any SQL error
     */
    @Override
    public Status delete(String tableName, String key) {
        try {
            final StatementType cacheKey = new StatementType(StatementType.Type.DELETE, tableName, 1, "",
                    getShardIndexByKey(key));
            PreparedStatement statement = cachedStatements.get(cacheKey);
            if (statement == null) {
                statement = createAndCacheDeleteStatement(cacheKey, key);
            }
            statement.setString(1, key);
            final int affectedRows = statement.executeUpdate();
            return (affectedRows == 1) ? Status.OK : Status.UNEXPECTED_STATE;
        } catch (SQLException e) {
            System.err.println("Error in processing delete to table: " + tableName + e);
            return Status.ERROR;
        }
    }

    /**
     * Collects the field names and values of {@code values} in the map's
     * iteration order.
     *
     * <p>Note that every value is converted with {@code toString()}.
     *
     * @param values field name to value
     * @return the names joined with commas together with the values in the
     *         same order
     */
    private OrderedFieldInfo getFieldInfo(HashMap<String, ByteIterator> values) {
        StringBuilder fieldKeys = new StringBuilder();
        List<String> fieldValues = new ArrayList<String>(values.size());
        for (Map.Entry<String, ByteIterator> entry : values.entrySet()) {
            // Separator goes before every name but the first. The value list
            // is used as the "first?" marker so an empty first name still works.
            if (!fieldValues.isEmpty()) {
                fieldKeys.append(",");
            }
            fieldKeys.append(entry.getKey());
            fieldValues.add(entry.getValue().toString());
        }

        return new OrderedFieldInfo(fieldKeys.toString(), fieldValues);
    }

    /**
     * Serialises {@code values} as a flat JSON object whose members are all
     * strings, e.g. {@code {"field0":"abc","field1":"def"}}.
     *
     * <p>Both member names and member values are JSON-escaped, so a quote or
     * backslash in either cannot produce malformed JSON.
     *
     * @param values field name to value
     * @return the JSON object text; {@code {}} for an empty map
     */
    private String getJsonString(HashMap<String, ByteIterator> values) {
        StringBuilder builder = new StringBuilder("{");
        boolean first = true;
        for (Map.Entry<String, String> entry : StringByteIterator.getStringMap(values).entrySet()) {
            if (!first) {
                builder.append(",");
            }
            first = false;
            builder.append("\"").append(StringEscapeUtils.escapeJson(entry.getKey())).append("\":");
            builder.append("\"").append(StringEscapeUtils.escapeJson(entry.getValue())).append("\"");
        }
        builder.append("}");

        return builder.toString();
    }

    /**
     * Builds a nested {@code jsonb_set(...)} SQL expression that applies every
     * entry of {@code values} to the {@code fields} column, e.g.
     * {@code jsonb_set(jsonb_set(fields, '{"a"}', '"1"'), '{"b"}', '"2"')}.
     * Returns the bare column name {@code fields} when the map is empty.
     *
     * <p>The result is SQL text with embedded literals. Single quotes inside
     * names and values are doubled so they cannot terminate the literal, and
     * the path element is escaped for the array-literal syntax.
     * NOTE(review): this assumes the server treats backslashes in ordinary
     * string literals literally ({@code standard_conforming_strings=on});
     * confirm.
     *
     * @param values field name to new value
     * @return the SQL expression computing the updated JSONB document
     */
    private String getJsonUpdateString(HashMap<String, ByteIterator> values) {
        String expression = "fields";

        for (Map.Entry<String, String> entry : StringByteIterator.getStringMap(values).entrySet()) {
            // Array-literal element: quoted, with \ and " escaped.
            String pathElement = entry.getKey().replace("\\", "\\\\").replace("\"", "\\\"");
            String pathLiteral = "{\"" + pathElement + "\"}";
            String jsonValue = "\"" + StringEscapeUtils.escapeJson(entry.getValue()) + "\"";

            StringBuilder builder = new StringBuilder();
            builder.append("jsonb_set(").append(expression).append(", '");
            builder.append(pathLiteral.replace("'", "''")).append("', '");
            builder.append(jsonValue.replace("'", "''")).append("')");
            expression = builder.toString();
        }

        return expression;
    }
}