com.inmobi.grill.driver.hive.TestRemoteHiveDriver.java Source code

Introduction

Here is the source code for com.inmobi.grill.driver.hive.TestRemoteHiveDriver.java, a TestNG suite that exercises the Grill HiveDriver against a standalone HiveServer2 instance over Thrift. It covers concurrent status polling, persistence of driver state across serialization, and partition pruning in query plans.

Source

package com.inmobi.grill.driver.hive;

/*
 * #%L
 * Grill Hive Driver
 * %%
 * Copyright (C) 2014 Inmobi
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import static org.testng.Assert.assertEquals;

import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import com.inmobi.grill.server.api.driver.DriverQueryPlan;
import com.inmobi.grill.server.api.driver.DriverQueryStatus.DriverQueryState;
import com.inmobi.grill.server.api.driver.GrillDriver;

import org.apache.commons.io.FileUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.Database;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hive.service.server.HiveServer2;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import com.inmobi.grill.api.GrillException;
import com.inmobi.grill.api.query.QueryHandle;
import com.inmobi.grill.server.api.GrillConfConstants;
import com.inmobi.grill.server.api.query.QueryContext;

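/**
 * Driver tests that run against a standalone HiveServer2 reached over Thrift,
 * rather than an embedded Hive session. Extends {@link TestHiveDriver} with
 * remote-only scenarios: concurrent status polling, persistence of driver
 * state across serialization, and partition pruning in query plans.
 */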
public class TestRemoteHiveDriver extends TestHiveDriver {
    public static final Log LOG = LogFactory.getLog(TestRemoteHiveDriver.class);
    static final String HS2_HOST = "localhost";
    static final int HS2_PORT = 12345;
    static HiveServer2 server;
    private static HiveConf remoteConf = new HiveConf();

    @BeforeClass
    public static void setupTest() throws Exception {
        createHS2Service();

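        // Create a scratch database named after the test class and switch to it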
        SessionState.start(remoteConf);
        Hive client = Hive.get(remoteConf);
        Database database = new Database();
        database.setName(TestRemoteHiveDriver.class.getSimpleName());
        client.createDatabase(database, true);
        SessionState.get().setCurrentDatabase(TestRemoteHiveDriver.class.getSimpleName());
    }

    public static void createHS2Service() throws Exception {
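        // Use a remote Thrift connection rather than an embedded one, and set up
        // the HiveServer2 bind address, client retry limits, socket timeout and
        // the driver's connection expiry delay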
        remoteConf.setClass(HiveDriver.GRILL_HIVE_CONNECTION_CLASS, RemoteThriftConnection.class,
                ThriftConnection.class);
        remoteConf.set("hive.lock.manager", "org.apache.hadoop.hive.ql.lockmgr.EmbeddedLockManager");
        remoteConf.setVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST, HS2_HOST);
        remoteConf.setIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_PORT, HS2_PORT);
        remoteConf.setIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_CLIENT_CONNECTION_RETRY_LIMIT, 3);
        remoteConf.setIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_CLIENT_RETRY_LIMIT, 3);
        remoteConf.setIntVar(HiveConf.ConfVars.SERVER_READ_SOCKET_TIMEOUT, 60000);
        remoteConf.setLong(HiveDriver.GRILL_CONNECTION_EXPIRY_DELAY, 10000);
        server = new HiveServer2();
        server.init(remoteConf);
        server.start();
        // TODO figure out a better way to wait for thrift service to start
        Thread.sleep(7000);
    }

    @AfterClass
    public static void cleanupTest() throws Exception {
        stopHS2Service();
        Hive.get(remoteConf).dropDatabase(TestRemoteHiveDriver.class.getSimpleName(), true, true, true);
    }

    public static void stopHS2Service() throws Exception {
        try {
            server.stop();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @BeforeMethod
    @Override
    public void beforeTest() throws Exception {
        conf = new HiveConf(remoteConf);
        // Check if hadoop property set
        System.out.println("###HADOOP_PATH " + System.getProperty("hadoop.bin.path"));
        Assert.assertNotNull(System.getProperty("hadoop.bin.path"));
        driver = new HiveDriver();
        driver.configure(conf);
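        // Run the USE statement with rewriting and persistence disabled, then
        // re-enable both for the queries the tests execute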
        conf.setBoolean(GrillConfConstants.GRILL_ADD_INSERT_OVEWRITE, false);
        conf.setBoolean(GrillConfConstants.GRILL_PERSISTENT_RESULT_SET, false);
        driver.execute(new QueryContext("USE " + TestRemoteHiveDriver.class.getSimpleName(), null, conf));
        conf.setBoolean(GrillConfConstants.GRILL_ADD_INSERT_OVEWRITE, true);
        conf.setBoolean(GrillConfConstants.GRILL_PERSISTENT_RESULT_SET, true);
        Assert.assertEquals(0, driver.getHiveHandleSize());
    }

    @AfterMethod
    @Override
    public void afterTest() throws Exception {
        LOG.info("Test finished, closing driver");
        driver.close();
    }

    @Test
    public void testMultiThreadClient() throws Exception {
        LOG.info("@@ Starting multi thread test");
        // Launch several async queries, each polled by multiple threads
        createTestTable("test_multithreads");
        HiveConf thConf = new HiveConf(conf, TestRemoteHiveDriver.class);
        thConf.setLong(HiveDriver.GRILL_CONNECTION_EXPIRY_DELAY, 10000);
        final HiveDriver thrDriver = new HiveDriver();
        thrDriver.configure(thConf);
        QueryContext ctx = new QueryContext("USE " + TestRemoteHiveDriver.class.getSimpleName(), null, conf);
        thrDriver.execute(ctx);

        // Launch a batch of select queries
        final int QUERIES = 5;
        int launchedQueries = 0;
        final int THREADS = 5;
        final long POLL_DELAY = 500;
        List<Thread> thrs = new ArrayList<Thread>();
        final AtomicInteger errCount = new AtomicInteger();
        for (int q = 0; q < QUERIES; q++) {
            final QueryContext qctx;
            try {
                qctx = new QueryContext("SELECT * FROM test_multithreads", null, conf);
                thrDriver.executeAsync(qctx);
            } catch (GrillException e) {
                errCount.incrementAndGet();
                LOG.info(q + " executeAsync error: " + e.getCause());
                continue;
            }
            LOG.info("@@ Launched query: " + q + " " + qctx.getQueryHandle());
            launchedQueries++;
            // Launch many threads to poll for status
            final QueryHandle handle = qctx.getQueryHandle();

            for (int i = 0; i < THREADS; i++) {
                int thid = q * THREADS + i;
                Thread th = new Thread(new Runnable() {
                    @Override
                    public void run() {
                        for (int i = 0; i < 1000; i++) {
                            try {
                                thrDriver.updateStatus(qctx);
                                if (qctx.getDriverStatus().isFinished()) {
                                    LOG.info("@@ " + handle.getHandleId() + " >> "
                                            + qctx.getDriverStatus().getState());
                                    thrDriver.closeQuery(handle);
                                    break;
                                }
                                Thread.sleep(POLL_DELAY);
                            } catch (GrillException e) {
                                LOG.error("Got Exception", e.getCause());
                                e.printStackTrace();
                                errCount.incrementAndGet();
                                break;
                            } catch (InterruptedException e) {
                                e.printStackTrace();
                                break;
                            }
                        }
                    }
                });
                thrs.add(th);
                th.setName("Poller#" + (thid));
                th.start();
            }
        }

        for (Thread th : thrs) {
            try {
                th.join(10000);
                // join() returns normally on timeout, so check liveness explicitly
                if (th.isAlive()) {
                    LOG.warn("Not ended yet: " + th.getName());
                }
            } catch (InterruptedException e) {
                LOG.warn("Interrupted while waiting for " + th.getName(), e);
            }
        }
        Assert.assertEquals(0, thrDriver.getHiveHandleSize());
        LOG.info("@@ Completed all pollers. Total thrift errors: " + errCount.get());
        assertEquals(launchedQueries, QUERIES);
        assertEquals(thrs.size(), QUERIES * THREADS);
        assertEquals(errCount.get(), 0);
    }

    @Test
    public void testHiveDriverPersistence() throws Exception {
        System.out.println("@@@@ start_persistence_test");
        HiveConf driverConf = new HiveConf(remoteConf, TestRemoteHiveDriver.class);
        driverConf.setLong(HiveDriver.GRILL_CONNECTION_EXPIRY_DELAY, 10000);

        final HiveDriver oldDriver = new HiveDriver();
        oldDriver.configure(driverConf);

        driverConf.setBoolean(GrillConfConstants.GRILL_ADD_INSERT_OVEWRITE, false);
        driverConf.setBoolean(GrillConfConstants.GRILL_PERSISTENT_RESULT_SET, false);
        QueryContext ctx = new QueryContext("USE " + TestRemoteHiveDriver.class.getSimpleName(), null, driverConf);
        oldDriver.execute(ctx);
        Assert.assertEquals(0, oldDriver.getHiveHandleSize());

        String tableName = "test_hive_driver_persistence";

        // Create some ops with a driver
        String createTable = "CREATE TABLE IF NOT EXISTS " + tableName + "(ID STRING)";
        ctx = new QueryContext(createTable, null, driverConf);
        oldDriver.execute(ctx);

        // Load some data into the table
        String dataLoad = "LOAD DATA LOCAL INPATH '" + TEST_DATA_FILE + "' OVERWRITE INTO TABLE " + tableName;
        ctx = new QueryContext(dataLoad, null, driverConf);
        oldDriver.execute(ctx);

        driverConf.setBoolean(GrillConfConstants.GRILL_ADD_INSERT_OVEWRITE, true);
        driverConf.setBoolean(GrillConfConstants.GRILL_PERSISTENT_RESULT_SET, true);
        // Fire two queries
        QueryContext ctx1 = new QueryContext("SELECT * FROM " + tableName, null, driverConf);
        oldDriver.executeAsync(ctx1);
        QueryContext ctx2 = new QueryContext("SELECT ID FROM " + tableName, null, driverConf);
        oldDriver.executeAsync(ctx2);
        Assert.assertEquals(2, oldDriver.getHiveHandleSize());

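        // Serialize both in-flight contexts so they can be restored against a new driver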
        byte[] ctx1bytes = persistContext(ctx1);
        byte[] ctx2bytes = persistContext(ctx2);

        // Write driver to stream
        ByteArrayOutputStream driverBytes = new ByteArrayOutputStream();
        ObjectOutputStream driverOut = new ObjectOutputStream(driverBytes);
        try {
            oldDriver.writeExternal(driverOut);
        } finally {
            // Flush buffered object data so it actually reaches the byte array
            driverOut.flush();
            driverBytes.close();
        }

        // Create another driver from the stream
        ByteArrayInputStream driverInput = new ByteArrayInputStream(driverBytes.toByteArray());
        HiveDriver newDriver = new HiveDriver();
        newDriver.readExternal(new ObjectInputStream(driverInput));
        newDriver.configure(driverConf);
        driverInput.close();

        ctx1 = readContext(ctx1bytes, newDriver);
        ctx2 = readContext(ctx2bytes, newDriver);

        Assert.assertEquals(2, newDriver.getHiveHandleSize());

        validateExecuteAsync(ctx1, DriverQueryState.SUCCESSFUL, true,
                GrillConfConstants.GRILL_RESULT_SET_PARENT_DIR_DEFAULT, false, newDriver);
        validateExecuteAsync(ctx2, DriverQueryState.SUCCESSFUL, true,
                GrillConfConstants.GRILL_RESULT_SET_PARENT_DIR_DEFAULT, false, newDriver);
    }

    private byte[] persistContext(QueryContext ctx) throws IOException {
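        // Layout: the serialized QueryContext, a boolean recording whether a
        // driver was selected and, if so, the selected driver's class name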
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream out = new ObjectOutputStream(baos);
        try {
            out.writeObject(ctx);
            boolean isDriverAvailable = (ctx.getSelectedDriver() != null);
            out.writeBoolean(isDriverAvailable);
            if (isDriverAvailable) {
                out.writeUTF(ctx.getSelectedDriver().getClass().getName());
            }
        } finally {
            out.flush();
            out.close();
            baos.close();
        }

        return baos.toByteArray();
    }

    private QueryContext readContext(byte[] bytes, GrillDriver driver) throws IOException, ClassNotFoundException {
        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
        ObjectInputStream in = new ObjectInputStream(bais);
        QueryContext ctx;
        try {
            ctx = (QueryContext) in.readObject();
            boolean driverAvailable = in.readBoolean();
            if (driverAvailable) {
                // Consume the class name written by persistContext; the actual
                // driver instance is supplied by the caller
                in.readUTF();
                ctx.setSelectedDriver(driver);
            }
        } finally {
            in.close();
            bais.close();
        }
        return ctx;
    }

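    // Creates an external table with the given number of partitions, each backed
    // by a directory under target/partdata containing a small sample data file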
    private void createPartitionedTable(String tableName, int partitions) throws Exception {
        conf.setBoolean(GrillConfConstants.GRILL_ADD_INSERT_OVEWRITE, false);
        conf.setBoolean(GrillConfConstants.GRILL_PERSISTENT_RESULT_SET, false);

        QueryContext ctx = new QueryContext("CREATE EXTERNAL TABLE IF NOT EXISTS " + tableName
                + " (ID STRING) PARTITIONED BY (DT STRING, ET STRING)", null, conf);

        driver.execute(ctx);
        Assert.assertEquals(0, driver.getHiveHandleSize());

        File dataDir = new File("target/partdata");
        dataDir.mkdir();

        // Add partitions
        for (int i = 0; i < partitions; i++) {
            // Create partition paths
            File tableDir = new File(dataDir, tableName);
            tableDir.mkdir();
            File partDir = new File(tableDir, "p" + i);
            partDir.mkdir();

            // Create data file
            File data = new File(partDir, "data.txt");
            FileUtils.writeLines(data, Arrays.asList("one", "two", "three", "four", "five"));

            System.out.println("@@ Adding partition " + i);
            QueryContext partCtx = new QueryContext("ALTER TABLE " + tableName
                    + " ADD IF NOT EXISTS PARTITION (DT='p" + i + "', ET='1') LOCATION '" + partDir.getPath() + "'",
                    null, conf);
            driver.execute(partCtx);
        }
    }

    @Test
    public void testPartitionInQueryPlan() throws Exception {
        // Create tables with 10 & 1 partitions respectively
        createPartitionedTable("table_1", 10);
        createPartitionedTable("table_2", 1);

        // Query should select 5 partitions of table_1 and 1 partition of table_2
        String explainQuery = "SELECT table_1.ID  "
                + "FROM table_1 LEFT OUTER JOIN table_2 ON table_1.ID = table_2.ID AND table_2.DT='p0' "
                + "WHERE table_1.DT='p0' OR table_1.DT='p1' OR table_1.DT='p2' OR table_1.DT='p3' OR table_1.DT='p4' "
                + "AND table_1.ET='1'";

        DriverQueryPlan plan = driver.explain(explainQuery, conf);

        Assert.assertEquals(0, driver.getHiveHandleSize());
        System.out.println("@@ partitions: " + plan.getPartitions());

        Assert.assertEquals(plan.getPartitions().size(), 2);

        String dbName = TestRemoteHiveDriver.class.getSimpleName().toLowerCase();
        Assert.assertTrue(plan.getPartitions().containsKey(dbName + ".table_1"));
        Assert.assertEquals(plan.getPartitions().get(dbName + ".table_1").size(), 5);

        Assert.assertTrue(plan.getPartitions().containsKey(dbName + ".table_2"));
        Assert.assertEquals(plan.getPartitions().get(dbName + ".table_2").size(), 1);

        FileUtils.deleteDirectory(new File("target/partdata"));
    }
}