com.sf.ddao.shards.ShardedDaoTest.java Source code

Java tutorial

Introduction

Here is the source code for com.sf.ddao.shards.ShardedDaoTest.java

Source

/*
 * Copyright 2008 Pavel Syrtsov
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and limitations
 *  under the License.
 */

package com.sf.ddao.shards;

import com.google.inject.Guice;
import com.google.inject.Injector;
import com.mockrunner.jdbc.JDBCTestModule;
import com.mockrunner.jdbc.PreparedStatementResultSetHandler;
import com.mockrunner.mock.jdbc.JDBCMockObjectFactory;
import com.mockrunner.mock.jdbc.MockResultSet;
import com.sf.ddao.*;
import com.sf.ddao.chain.ChainModule;
import com.sf.ddao.factory.param.ThreadLocalParameter;
import com.sf.ddao.orm.RSMapper;
import com.sf.ddao.orm.UseRSMapper;
import com.sf.ddao.orm.rsmapper.rowmapper.BeanRowMapperFactory;
import com.sf.ddao.orm.rsmapper.rowmapper.RowMapper;
import junit.framework.TestCase;
import org.apache.commons.chain.Context;
import org.mockejb.jndi.MockContextFactory;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * Created by: Pavel Syrtsov
 * Date: Apr 6, 2007
 * Time: 7:00:11 PM
 */
public class ShardedDaoTest extends TestCase {
    Injector injector;
    private static final String PART_NAME = "testPartName";
    private JDBCTestModule testModule1;
    private JDBCTestModule testModule2;

    @ShardedDao(TestShardingService.class)
    public static interface TestUserDao extends TransactionableDao {
        /**
         * in this statement we assume that 1st method arg is Java Bean
         * and refer to property by name. It works same way for Map.
         *
         * @param userBean - parameter object
         * @return object created from data returned by sql
         */
        @Select("select id, name from user where id = #id#")
        TestUserBean getUser(@ShardKey("id") TestUserBean userBean);

        /**
         * MultiShardSelect annotation allows to execute SQl statement
         * on multiple shards and takes care of merging results from multiple shards
         * by default it assumes that result is a collection of objects and will do merge of collections.
         * To provide custom merger logic annotation allows to define value for resultMerger class.
         *
         * @param userIdList
         * @return merged
         */
        @MultiShardSelect("select id, name from user_data where user_id in ($ctx:keyList$)")
        List<TestUserBean> getUserDataList(@ShardKey List<Integer> userIdList);

        /**
         * 1st parameter passed by reference, 2nd by value (by injecting result of toString() into SQL).
         *
         * @param tableName name of table
         * @param size      - max size of array
         * @param userId    - query parameter
         * @return objects created from data returned by sql
         */
        @Select("select id, name from $0$ where user_id = #2# limit #1#")
        TestUserBean[] getUserDataArray(String tableName, int size, @ShardKey long userId);

        @Select("select id, name from user_data where user_id = #0#")
        void processUserData(@ShardKey long userId, @UseRSMapper RSMapper selectCallback);

        /**
         * values that have ':' with prefix assumed to be call to predefined static function registered by ParameterService,
         * there are few if them predefined:
         * prefix threadLocal: allows to pass value using ThreadLocal
         * prefix ctx: allows to pass value using Context object in method arguments
         * prefix joinList: allows to join list of keys in comma separated string
         *
         * @param userId - query paramter
         * @return value returned by query
         */
        @Select("select id from user_data where part = '$threadLocal:" + PART_NAME + "$' and user_id = #0#")
        int getUserData(@ShardKey long userId);

        @SelectThenInsert({ "select nextval from userIdSequence",
                "insert into user(id,name) values(#threadLocal:id#, #name#)" })
        long addUser(@ShardKey("id") TestUserBean user);

    }

    protected void setUp() throws Exception {
        this.injector = Guice.createInjector(new ChainModule(TestUserDao.class));
        super.setUp();
        MockContextFactory.setAsInitial();

        JDBCMockObjectFactory mockFactory1 = new JDBCMockObjectFactory();
        testModule1 = new JDBCTestModule(mockFactory1);
        JDBCMockObjectFactory mockFactory2 = new JDBCMockObjectFactory();
        testModule2 = new JDBCTestModule(mockFactory2);

        final TestShardingService controlDao = injector.getInstance(TestShardingService.class);
        controlDao.setDS1(mockFactory1.getMockDataSource());
        controlDao.setDS2(mockFactory2.getMockDataSource());
    }

    protected void tearDown() throws Exception {
        super.tearDown();
        MockContextFactory.revertSetAsInitial();
    }

    private void createResultSet(JDBCTestModule testModule, Object... data) {
        PreparedStatementResultSetHandler handler = testModule.getPreparedStatementResultSetHandler();
        MockResultSet rs = handler.createResultSet();
        for (int i = 0; i < data.length; i++) {
            Object colName = data[i++];
            Object colValues = data[i];
            rs.addColumn(colName.toString(), (Object[]) colValues);
        }
        handler.prepareGlobalResultSet(rs);
    }

    public void testSingleRecordGet() throws Exception {
        // create dao object
        TestUserDao dao = injector.getInstance(TestUserDao.class);

        // reuse it for multiple invocations
        getUserOnce(testModule1, dao, 1, "foo1", false);
        getUserOnce(testModule1, dao, 10, "foo2", false);
        getUserOnce(testModule2, dao, 11, "bar1", false);
        getUserOnce(testModule2, dao, 20, "bar2", false);
    }

    private void getUserOnce(JDBCTestModule testModule, TestUserDao dao, int id, String name, boolean inTx)
            throws Exception {
        // setup test
        TestUserBean data = new TestUserBean(true);
        data.setId(id);
        data.setName(name);
        createResultSet(testModule, "id", new Object[] { data.getId() }, "name", new Object[] { data.getName() });

        // execute dao method
        TestUserBean res = dao.getUser(data);

        // verify result
        assertNotNull(res);
        assertEquals(res.getId(), data.getId());
        assertEquals(res.getName(), data.getName());

        testModule.verifySQLStatementExecuted("select id, name from user where id = ?");
        testModule.verifyAllResultSetsClosed();
        testModule.verifyAllStatementsClosed();
        if (!inTx) {
            testModule.verifyConnectionClosed();
        }
    }

    public void testMultiShardGetRecordList() throws Exception {
        TestUserDao dao = injector.getInstance(TestUserDao.class);
        // setup test
        createResultSet(testModule1, "id", new Object[] { 1, 2 }, "name", new Object[] { "u1", "u2" });
        createResultSet(testModule2, "id", new Object[] { 15, 16 }, "name", new Object[] { "u15", "u16" });

        List<Integer> userIdList = new ArrayList<Integer>();
        userIdList.add(1);
        userIdList.add(2);
        userIdList.add(15);
        userIdList.add(16);
        // execute dao method
        List<TestUserBean> res = dao.getUserDataList(userIdList);

        Collections.sort(res, new Comparator<TestUserBean>() {
            public int compare(TestUserBean testUserBean, TestUserBean testUserBean1) {
                return (int) (testUserBean.getId() - testUserBean1.getId());
            }
        });

        // verify result
        assertNotNull(res);
        assertEquals(4, res.size());

        assertEquals(1, res.get(0).getId());
        assertEquals("u1", res.get(0).getName());
        assertEquals(2, res.get(1).getId());
        assertEquals("u2", res.get(1).getName());

        assertEquals(15, res.get(2).getId());
        assertEquals("u15", res.get(2).getName());
        assertEquals(16, res.get(3).getId());
        assertEquals("u16", res.get(3).getName());

        testModule1.verifySQLStatementExecuted("select id, name from user");
        testModule1.verifyAllResultSetsClosed();
        testModule1.verifyAllStatementsClosed();
        testModule1.verifyConnectionClosed();

        testModule2.verifySQLStatementExecuted("select id, name from user");
        testModule2.verifyAllResultSetsClosed();
        testModule2.verifyAllStatementsClosed();
        testModule2.verifyConnectionClosed();
    }

    public void testGetUserArray() throws Exception {
        // execute dao method
        TestUserDao dao = injector.getInstance(TestUserDao.class);
        getUserDataArray(dao, testModule1, 1);
        getUserDataArray(dao, testModule1, 10);
        getUserDataArray(dao, testModule2, 11);
        getUserDataArray(dao, testModule2, 20);

    }

    private void getUserDataArray(TestUserDao dao, JDBCTestModule testModule, int userId) {
        // setup test
        createResultSet(testModule, "id", new Object[] { 1, 2 }, "name", new Object[] { "foo", "bar" });

        TestUserBean[] res = dao.getUserDataArray("user", 2, userId);

        // verify result
        assertNotNull(res);
        assertEquals(res.length, 2);
        assertEquals(res[0].getId(), 1);
        assertEquals(res[0].getName(), "foo");
        assertEquals(res[1].getId(), 2);
        assertEquals(res[1].getName(), "bar");

        testModule.verifySQLStatementExecuted("select id, name from user where user_id = ? limit ?");
        testModule.verifySQLStatementParameter("select id, name from user where user_id = ? limit ?", 0, 2, 2);
        testModule.verifyAllResultSetsClosed();
        testModule.verifyAllStatementsClosed();
        testModule.verifyConnectionClosed();
    }

    public void testSelectCallback() throws Exception {
        TestUserDao dao = injector.getInstance(TestUserDao.class);
        processUserData(dao, testModule1);

    }

    private void processUserData(TestUserDao dao, JDBCTestModule testModule) {
        // setup test
        createResultSet(testModule, "id", new Object[] { 1, 2 }, "name", new Object[] { "foo", "bar" });
        final List<TestUserBean> res = new ArrayList<TestUserBean>();

        // execute dao method
        dao.processUserData(1, new RSMapper() {
            RowMapper rowMapper = new BeanRowMapperFactory(TestUserBean.class).get();

            public Object handle(Context context, ResultSet rs) throws SQLException {
                while (rs.next()) {
                    res.add((TestUserBean) rowMapper.map(rs));
                }
                return null;

            }
        });

        // verify result
        assertNotNull(res);
        assertEquals(res.size(), 2);
        assertEquals(res.get(0).getId(), 1);
        assertEquals(res.get(0).getName(), "foo");
        assertEquals(res.get(1).getId(), 2);
        assertEquals(res.get(1).getName(), "bar");

        testModule.verifySQLStatementExecuted("select id, name from user");
        testModule.verifyAllResultSetsClosed();
        testModule.verifyAllStatementsClosed();
        testModule.verifyConnectionClosed();
    }

    public void testUsingStaticFunction() throws Exception {
        TestUserDao dao = injector.getInstance(TestUserDao.class);
        getUserData(dao, 1, testModule1, 0);
        getUserData(dao, 10, testModule1, 1);
        getUserData(dao, 11, testModule2, 0);
        getUserData(dao, 20, testModule2, 1);

    }

    private void getUserData(TestUserDao dao, long userId, JDBCTestModule testModule, int idx) {
        // setup test
        final int id = 11;
        final String testPart = "testPart";
        createResultSet(testModule, "id", new Object[] { id });
        ThreadLocalParameter.put(PART_NAME, testPart);

        // execute dao method
        int res = dao.getUserData(userId);

        // verify result
        ThreadLocalParameter.remove(PART_NAME);
        assertEquals(id, res);

        testModule.verifySQLStatementExecuted(
                "select id from user_data where part = '" + testPart + "' and user_id = ?");
        testModule.verifyPreparedStatementParameter(idx, 1, userId);
        testModule.verifyAllResultSetsClosed();
        testModule.verifyAllStatementsClosed();
        testModule.verifyConnectionClosed();
    }

    public void testTx() throws Exception {
        final long id = 7;
        final String testName = "testName";

        // execute dao method
        final TestUserDao dao = injector.getInstance(TestUserDao.class);
        final TestUserBean user = new TestUserBean(true);
        user.setName(testName);
        TxHelper.execInTx(dao, new Runnable() {
            public void run() {
                try {
                    createResultSet(testModule1, "nextval", new Object[] { id });
                    final long res = dao.addUser(user);
                    final Connection connection1 = TxHelper.getConnectionOnHold();
                    assertNotNull(connection1);
                    assertFalse(connection1.isClosed());
                    testModule1.verifyNotCommitted();
                    getUserOnce(testModule1, dao, 11, "user11", true);
                    final Connection connection2 = TxHelper.getConnectionOnHold();
                    assertSame(connection1, connection2);
                    assertFalse(connection2.isClosed());
                    testModule1.verifyNotCommitted();
                    assertEquals(id, res);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        }, id);
        final Connection connection = TxHelper.getConnectionOnHold();
        assertNull(connection);
        testModule1.verifyCommitted();
        testModule1.verifySQLStatementExecuted("select nextval from userIdSequence");
        testModule1.verifySQLStatementExecuted("insert into user(id,name) values(?, ?)");
        testModule1.verifyPreparedStatementParameter(1, 1, id);
        testModule1.verifyPreparedStatementParameter(1, 2, testName);
        testModule1.verifyAllResultSetsClosed();
        testModule1.verifyAllStatementsClosed();
        testModule1.verifyConnectionClosed();
    }

}