com.thinkbiganalytics.hive.service.HiveService.java Source code

Java tutorial

Introduction

Here is the source code for com.thinkbiganalytics.hive.service.HiveService.java

Source

package com.thinkbiganalytics.hive.service;

/*-
 * #%L
 * thinkbig-thrift-proxy-core
 * %%
 * Copyright (C) 2017 ThinkBig Analytics
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.thinkbiganalytics.discovery.model.DefaultQueryResult;
import com.thinkbiganalytics.discovery.model.DefaultQueryResultColumn;
import com.thinkbiganalytics.discovery.schema.Field;
import com.thinkbiganalytics.discovery.schema.QueryResult;
import com.thinkbiganalytics.discovery.schema.QueryResultColumn;
import com.thinkbiganalytics.discovery.schema.TableSchema;
import com.thinkbiganalytics.discovery.util.ParserHelper;
import com.thinkbiganalytics.hive.util.HiveUtils;
import com.thinkbiganalytics.kerberos.KerberosTicketConfiguration;
import com.thinkbiganalytics.schema.DBSchemaParser;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import javax.inject.Inject;
import javax.sql.DataSource;

/**
 */
public class HiveService {

    private static final Logger log = LoggerFactory.getLogger(HiveService.class);

    @Inject
    @Qualifier("hiveJdbcTemplate")
    private JdbcTemplate jdbcTemplate;

    @Inject
    @Qualifier("kerberosHiveConfiguration")
    private KerberosTicketConfiguration kerberosHiveConfiguration;

    private DBSchemaParser schemaParser = null;

    public DataSource getDataSource() {
        return jdbcTemplate.getDataSource();
    }

    public DBSchemaParser getDBSchemaParser() {
        if (schemaParser == null) {
            schemaParser = new DBSchemaParser(getDataSource(), kerberosHiveConfiguration);
        }
        return schemaParser;
    }

    public List<String> getSchemaNames() {
        return getDBSchemaParser().listSchemas();
    }

    public List<String> getTables(String schema) {
        List<String> tables = getDBSchemaParser().listTables(schema, null);
        return tables;
    }

    /**
     * returns a list of schemanName.TableName
     */
    private List<String> getAllTables() {
        List<String> allTables = new ArrayList<>();
        List<String> schemas = getSchemaNames();
        if (schemas != null) {
            for (String schema : schemas) {
                List<String> tables = getTables(schema);
                if (tables != null) {
                    for (String table : tables) {
                        allTables.add(schema + "." + table);
                    }
                }
            }
        }
        return allTables;
    }

    /**
     * returns a list of all the scheam.tablename for a given schema
     */
    public List<String> getTablesForImpersonatedUser(String schema) {
        QueryResult tables = query("show tables in " + schema);
        return tables.getRows().stream().flatMap(row -> row.entrySet().stream())
                .map(e -> schema + "." + e.getValue().toString()).collect(Collectors.toList());
    }

    /**
     * returns a list of all the schema.tablename
     */
    public List<String> getAllTablesForImpersonatedUser() {
        long start = System.currentTimeMillis();
        List<String> allTables = new ArrayList<>();
        QueryResult result = query("show databases");
        List<Object> databases = result.getRows().stream().flatMap(row -> row.entrySet().stream())
                .map(e -> e.getValue()).collect(Collectors.toList());
        databases.stream().forEach(database -> allTables.addAll(getTablesForImpersonatedUser(database.toString())));
        log.debug("time to get all tables " + (System.currentTimeMillis() - start) + " ms");
        return allTables;
    }

    public boolean testConnection() throws SQLException {
        return ((RefreshableDataSource) this.getDataSource()).testConnection();

    }

    /**
     * returns a list of populated TableSchema objects
     */
    public List<TableSchema> getAllTableSchemas() {
        List<TableSchema> allTables = new ArrayList<>();
        List<String> schemas = getSchemaNames();
        if (schemas != null) {
            for (String schema : schemas) {
                List<String> tables = getTables(schema);
                if (tables != null) {
                    for (String table : tables) {
                        allTables.add(getTableSchema(schema, table));
                    }
                }
            }
        }
        return allTables;
    }

    /**
     * Describes the given Table
     */
    public TableSchema getTableSchema(String schema, String table) {
        return getDBSchemaParser().describeTable(schema, table);
    }

    public List<? extends Field> getFields(String schema, String table) {
        TableSchema tableSchema = getTableSchema(schema, table);
        if (tableSchema != null) {
            return tableSchema.getFields();
        }
        return null;
    }

    public QueryResult browse(String schema, String table, String where, Integer limit) throws DataAccessException {

        if (where == null) {
            where = "";
        }
        String query = "SELECT * from " + HiveUtils.quoteIdentifier(schema, table) + " " + where + " LIMIT "
                + limit;
        return browse(query);
    }

    public QueryResult browse(String query) throws DataAccessException {
        return query(query);

    }

    // TODO: Temporary until we determine how we want to ensure DDL isn't sent through
    private String safeQuery(String query) {
        return "SELECT kylo_.* FROM (" + query + ") kylo_ LIMIT 1000";
    }

    public QueryResult query(String query) throws DataAccessException {
        final DefaultQueryResult queryResult = new DefaultQueryResult(query);
        final List<QueryResultColumn> columns = new ArrayList<>();
        final Map<String, Integer> displayNameMap = new HashMap<>();
        if (query != null && !query.toLowerCase().startsWith("show")) {
            query = safeQuery(query);
        }
        try {
            //  Setting in order to query complex formats like parquet
            jdbcTemplate.execute("set hive.optimize.index.filter=false");
            jdbcTemplate.query(query, new RowMapper<Map<String, Object>>() {
                @Override
                public Map<String, Object> mapRow(ResultSet rs, int rowNum) throws SQLException {
                    if (columns.isEmpty()) {
                        ResultSetMetaData rsMetaData = rs.getMetaData();
                        for (int i = 1; i <= rsMetaData.getColumnCount(); i++) {
                            String colName = rsMetaData.getColumnName(i);
                            DefaultQueryResultColumn column = new DefaultQueryResultColumn();
                            column.setField(rsMetaData.getColumnName(i));
                            String displayName = rsMetaData.getColumnLabel(i);
                            column.setHiveColumnLabel(displayName);
                            //remove the table name if it exists
                            displayName = StringUtils.substringAfterLast(displayName, ".");
                            Integer count = 0;
                            if (displayNameMap.containsKey(displayName)) {
                                count = displayNameMap.get(displayName);
                                count++;
                            }
                            displayNameMap.put(displayName, count);
                            column.setDisplayName(displayName + "" + (count > 0 ? count : ""));

                            column.setTableName(StringUtils.substringAfterLast(rsMetaData.getColumnName(i), "."));
                            column.setDataType(ParserHelper.sqlTypeToHiveType(rsMetaData.getColumnType(i)));
                            columns.add(column);
                        }
                        queryResult.setColumns(columns);
                    }
                    Map<String, Object> row = new LinkedHashMap<>();
                    for (QueryResultColumn column : columns) {
                        row.put(column.getDisplayName(), rs.getObject(column.getHiveColumnLabel()));
                    }
                    queryResult.addRow(row);
                    return row;
                }
            });

        } catch (DataAccessException dae) {
            dae.printStackTrace();
            throw dae;
        }
        return queryResult;

    }

}