org.apache.spark.sql.parser.SemanticAnalyzer.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.spark.sql.parser.SemanticAnalyzer.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.parser;

import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.antlr.runtime.tree.Tree;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde.serdeConstants;
import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;

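/**
 * Static helpers for semantic analysis of parsed SQL: unescaping string
 * literals and identifiers, rendering type AST nodes as type-name strings,
 * and formatting error messages with source positions.
 */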
public abstract class SemanticAnalyzer {
    /**
     * Decodes a character-set-qualified literal into a Java string.
     *
     * <p>Two literal forms are accepted: a quoted SQL string (first character is a
     * single quote), which is unescaped and then re-decoded with the named character
     * set, and a hexadecimal literal prefixed with {@code 0x}, whose digit pairs are
     * taken as raw bytes in the named character set.
     *
     * @param charSetName character set name with a leading underscore, which is stripped
     * @param charSetString the quoted or hexadecimal literal text
     * @return the decoded string
     * @throws SemanticException if the character set is not supported, or if the
     *         hexadecimal literal has an odd number of digits or contains a
     *         character that is not a hexadecimal digit
     */
    public static String charSetString(String charSetName, String charSetString) throws SemanticException {
        try {
            // The character set name starts with a _, so strip that
            charSetName = charSetName.substring(1);
            if (charSetString.charAt(0) == '\'') {
                // NOTE(review): getBytes() encodes with the platform default charset before
                // the bytes are re-decoded with charSetName; kept unchanged for compatibility.
                return new String(unescapeSQLString(charSetString).getBytes(), charSetName);
            }

            // hex input is also supported
            assert charSetString.charAt(0) == '0';
            assert charSetString.charAt(1) == 'x';
            String hexDigits = charSetString.substring(2);
            if (hexDigits.length() % 2 != 0) {
                // Previously this surfaced as a StringIndexOutOfBoundsException
                throw new SemanticException(
                        "Hexadecimal literal must have an even number of digits: " + charSetString);
            }

            byte[] bArray = new byte[hexDigits.length() / 2];
            for (int i = 0; i < bArray.length; i++) {
                int high = Character.digit(hexDigits.charAt(2 * i), 16);
                int low = Character.digit(hexDigits.charAt(2 * i + 1), 16);
                if (high < 0 || low < 0) {
                    // Character.digit returns -1 for characters that are not hex digits
                    throw new SemanticException("Invalid hexadecimal digit in literal: " + charSetString);
                }
                // The narrowing cast maps 128..255 onto the negative byte range
                bArray[i] = (byte) (high * 16 + low);
            }
            return new String(bArray, charSetName);
        } catch (UnsupportedEncodingException e) {
            throw new SemanticException(e);
        }
    }

    /**
     * Remove the encapsulating "`" pair from the identifier. We allow users to
     * use "`" to escape identifier for table names, column names and aliases, in
     * case that coincide with Hive language keywords.
     *
     * <p>Input that is not wrapped in a pair of backticks, including the empty
     * string and a lone backtick, is returned unchanged.
     *
     * @param val identifier text, possibly wrapped in backticks; may be {@code null}
     * @return the identifier without its enclosing backticks, or {@code null} if
     *         {@code val} is {@code null}
     */
    public static String unescapeIdentifier(String val) {
        if (val == null) {
            return null;
        }
        int last = val.length() - 1;
        // At least two characters are needed for an opening and a closing backtick;
        // without this guard "" and "`" would throw StringIndexOutOfBoundsException.
        if (last >= 1 && val.charAt(0) == '`' && val.charAt(last) == '`') {
            return val.substring(1, last);
        }
        return val;
    }

    /**
     * Copies the key/value property pairs found under an AST node into a map.
     *
     * <p>Every child of {@code prop} is read as one pair: its first child holds
     * the key and its optional second child holds the value. Both texts are
     * unescaped as SQL string literals. A pair without a value node is stored
     * with a {@code null} value.
     *
     * @param prop ASTNode parent of the key/value pairs
     *
     * @param mapProp property map which receives the mappings
     */
    public static void readProps(ASTNode prop, Map<String, String> mapProp) {
        int index = 0;
        while (index < prop.getChildCount()) {
            Tree pair = prop.getChild(index);
            String key = unescapeSQLString(pair.getChild(0).getText());
            Tree valueNode = pair.getChild(1);
            String value = (valueNode == null) ? null : unescapeSQLString(valueNode.getText());
            mapProp.put(key, value);
            index++;
        }
    }

    /**
     * Strips the enclosing quotes from a SQL string literal and resolves its
     * backslash escape sequences.
     *
     * <p>Characters outside a single- or double-quoted section are dropped. Inside
     * a quoted section the following escapes are recognised, each only when enough
     * characters follow the backslash:
     * <ul>
     * <li>backslash, the letter {@code u} and four hexadecimal digits: the UTF-16
     * code unit with that hexadecimal value;</li>
     * <li>backslash and three octal digits in the range 000-177: the character
     * with that code;</li>
     * <li>backslash and a single character: {@code 0 b n r t Z} map to control
     * characters, {@code %} and {@code _} keep their backslash, and any other
     * character stands for itself.</li>
     * </ul>
     * A unicode escape whose four digits are not all hexadecimal is not treated
     * as a unicode escape and falls through to the single-character rule.
     *
     * @param b the literal text including its quotes
     * @return the unescaped contents of the quoted sections
     */
    @SuppressWarnings("nls")
    public static String unescapeSQLString(String b) {
        Character enclosure = null;

        StringBuilder sb = new StringBuilder(b.length());
        for (int i = 0; i < b.length(); i++) {

            char currentChar = b.charAt(i);
            if (enclosure == null) {
                if (currentChar == '\'' || currentChar == '"') {
                    enclosure = currentChar;
                }
                // ignore all other chars outside the enclosure
                continue;
            }

            if (enclosure.charValue() == currentChar) {
                enclosure = null;
                continue;
            }

            if (currentChar != '\\') {
                sb.append(currentChar);
                continue;
            }

            // Unicode escape: four hex digits, accumulated in base 16. The earlier
            // implementation weighted the digits with decimal powers of ten, which
            // produced the wrong code unit for almost every input.
            if (i + 6 < b.length() && b.charAt(i + 1) == 'u') {
                int code = 0;
                boolean validHex = true;
                for (int j = i + 2; j < i + 6; j++) {
                    int digit = Character.digit(b.charAt(j), 16);
                    if (digit < 0) {
                        validHex = false;
                        break;
                    }
                    code = (code << 4) + digit;
                }
                if (validHex) {
                    sb.append((char) code);
                    i += 5;
                    continue;
                }
            }

            // Octal escape limited to 000-177, so the value always fits in 7 bits
            // and can be appended directly without going through a charset decoder.
            if (i + 4 < b.length()) {
                char i1 = b.charAt(i + 1);
                char i2 = b.charAt(i + 2);
                char i3 = b.charAt(i + 3);
                if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) {
                    sb.append((char) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)));
                    i += 3;
                    continue;
                }
            }

            if (i + 2 < b.length()) {
                char n = b.charAt(i + 1);
                switch (n) {
                case '0':
                    sb.append('\0');
                    break;
                case '\'':
                    sb.append('\'');
                    break;
                case '"':
                    sb.append('"');
                    break;
                case 'b':
                    sb.append('\b');
                    break;
                case 'n':
                    sb.append('\n');
                    break;
                case 'r':
                    sb.append('\r');
                    break;
                case 't':
                    sb.append('\t');
                    break;
                case 'Z':
                    sb.append((char) 0x1A);
                    break;
                case '\\':
                    sb.append('\\');
                    break;
                // The following 2 cases are exactly what MySQL does: the backslash
                // is kept so that LIKE patterns can still see the escaped wildcard.
                case '%':
                    sb.append("\\%");
                    break;
                case '_':
                    sb.append("\\_");
                    break;
                default:
                    sb.append(n);
                    break;
                }
                i++;
            } else {
                // Not enough characters left for an escape; keep the backslash
                sb.append(currentChar);
            }
        }
        return sb.toString();
    }

    /**
     * Get the list of FieldSchema out of the ASTNode.
     *
     * <p>Each child of {@code ast} describes one column: child 0 is the column
     * name, child 1 its type and the optional child 2 its comment. A child with
     * no name node contributes a {@code FieldSchema} with nothing set.
     *
     * @param ast parent node of the column definitions
     * @param lowerCase whether column names are converted to lower case
     * @return one {@code FieldSchema} per child of {@code ast}, in order
     * @throws SemanticException if a column type cannot be converted to a type string
     */
    public static List<FieldSchema> getColumns(ASTNode ast, boolean lowerCase) throws SemanticException {
        List<FieldSchema> colList = new ArrayList<FieldSchema>();
        int numCh = ast.getChildCount();
        for (int i = 0; i < numCh; i++) {
            FieldSchema col = new FieldSchema();
            ASTNode child = (ASTNode) ast.getChild(i);
            Tree grandChild = child.getChild(0);
            if (grandChild != null) {
                String name = grandChild.getText();
                if (lowerCase) {
                    // Locale.ROOT keeps the result independent of the JVM default
                    // locale (for example, Turkish maps 'I' to a dotless 'i').
                    name = name.toLowerCase(Locale.ROOT);
                }
                // child 0 is the name of the column
                col.setName(unescapeIdentifier(name));
                // child 1 is the type of the column
                ASTNode typeChild = (ASTNode) (child.getChild(1));
                col.setType(getTypeStringFromAST(typeChild));

                // child 2 is the optional comment of the column
                if (child.getChildCount() == 3) {
                    col.setComment(unescapeSQLString(child.getChild(2).getText()));
                }
            }
            colList.add(col);
        }
        return colList;
    }

    /**
     * Renders a type AST node as a type-name string.
     *
     * <p>List, map, struct and union nodes are rendered recursively from their
     * children; every other node is resolved through {@link #getTypeName(ASTNode)}.
     *
     * @param typeNode the AST node describing a type
     * @return the type string for the node
     * @throws SemanticException if a nested type cannot be rendered
     */
    protected static String getTypeStringFromAST(ASTNode typeNode) throws SemanticException {
        final int kind = typeNode.getType();

        if (kind == SparkSqlParser.TOK_LIST) {
            String elementType = getTypeStringFromAST((ASTNode) typeNode.getChild(0));
            return serdeConstants.LIST_TYPE_NAME + "<" + elementType + ">";
        }
        if (kind == SparkSqlParser.TOK_MAP) {
            String keyType = getTypeStringFromAST((ASTNode) typeNode.getChild(0));
            String valueType = getTypeStringFromAST((ASTNode) typeNode.getChild(1));
            return serdeConstants.MAP_TYPE_NAME + "<" + keyType + "," + valueType + ">";
        }
        if (kind == SparkSqlParser.TOK_STRUCT) {
            return getStructTypeStringFromAST(typeNode);
        }
        if (kind == SparkSqlParser.TOK_UNIONTYPE) {
            return getUnionTypeStringFromAST(typeNode);
        }
        return getTypeName(typeNode);
    }

    /**
     * Renders a struct type node as {@code struct<name:type,...>}.
     *
     * <p>The fields are the children of the node's first child; each field has
     * its name as child 0 and its type as child 1.
     *
     * @param typeNode the struct type node
     * @return the struct type string
     * @throws SemanticException if the struct has no fields or a field type
     *         cannot be rendered
     */
    private static String getStructTypeStringFromAST(ASTNode typeNode) throws SemanticException {
        ASTNode fieldList = (ASTNode) typeNode.getChild(0);
        final int fieldCount = fieldList.getChildCount();
        if (fieldCount <= 0) {
            throw new SemanticException("empty struct not allowed.");
        }

        StringBuilder out = new StringBuilder(serdeConstants.STRUCT_TYPE_NAME + "<");
        for (int idx = 0; idx < fieldCount; idx++) {
            if (idx > 0) {
                out.append(',');
            }
            ASTNode field = (ASTNode) fieldList.getChild(idx);
            String fieldName = unescapeIdentifier(field.getChild(0).getText());
            String fieldType = getTypeStringFromAST((ASTNode) field.getChild(1));
            out.append(fieldName).append(':').append(fieldType);
        }
        return out.append('>').toString();
    }

    /**
     * Renders a union type node as {@code uniontype<type,...>}.
     *
     * <p>The member types are the children of the node's first child.
     *
     * @param typeNode the union type node
     * @return the union type string
     * @throws SemanticException if the union has no members or a member type
     *         cannot be rendered
     */
    private static String getUnionTypeStringFromAST(ASTNode typeNode) throws SemanticException {
        ASTNode memberList = (ASTNode) typeNode.getChild(0);
        final int memberCount = memberList.getChildCount();
        if (memberCount <= 0) {
            throw new SemanticException("empty union not allowed.");
        }

        StringBuilder out = new StringBuilder(serdeConstants.UNION_TYPE_NAME + "<");
        for (int idx = 0; idx < memberCount; idx++) {
            if (idx > 0) {
                out.append(',');
            }
            out.append(getTypeStringFromAST((ASTNode) memberList.getChild(idx)));
        }
        return out.append('>').toString();
    }

    /**
     * Returns the text of the last leaf beneath a node, found by repeatedly
     * descending into the last child until a node without children is reached.
     *
     * @param tree the node to start from
     * @return the text of that leaf, or of {@code tree} itself if it has no children
     */
    public static String getAstNodeText(ASTNode tree) {
        ASTNode current = tree;
        while (current.getChildCount() != 0) {
            current = (ASTNode) current.getChild(current.getChildCount() - 1);
        }
        return current.getText();
    }

    /**
     * Builds an error message prefixed with the position of an AST node.
     *
     * <p>The result has the form {@code line:column message. Error encountered
     * near token 'text'}, where the text comes from {@link #getAstNodeText(ASTNode)}.
     * When no node is available the message is returned with a note that the
     * position is unknown.
     *
     * @param ast the node the error relates to; may be {@code null}
     * @param message the error description
     * @return the formatted error message
     */
    public static String generateErrorMessage(ASTNode ast, String message) {
        if (ast == null) {
            return message + ". Cannot tell the position of null AST.";
        }
        String position = ast.getLine() + ":" + ast.getCharPositionInLine();
        return position + " " + message + ". Error encountered near token '" + getAstNodeText(ast) + "'";
    }

    private static final Map<Integer, String> TokenToTypeName = new HashMap<Integer, String>();

    static {
        TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME);
        TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME);
    }

    /**
     * Resolves the type name for a non-nested type node.
     *
     * <p>Char, varchar and decimal nodes yield the qualified name produced by
     * the matching {@code ParseUtils} helper; all other tokens are looked up in
     * the token-to-type-name table, which gives {@code null} for a token it does
     * not contain.
     *
     * @param node the type node
     * @return the type name, or {@code null} if the token has no table entry
     * @throws SemanticException if the node is a datetime type, which is not supported
     */
    public static String getTypeName(ASTNode node) throws SemanticException {
        final int token = node.getType();

        // datetime type isn't currently supported
        if (token == SparkSqlParser.TOK_DATETIME) {
            throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg());
        }

        switch (token) {
        case SparkSqlParser.TOK_CHAR:
            return ParseUtils.getCharTypeInfo(node).getQualifiedName();
        case SparkSqlParser.TOK_VARCHAR:
            return ParseUtils.getVarcharTypeInfo(node).getQualifiedName();
        case SparkSqlParser.TOK_DECIMAL:
            return ParseUtils.getDecimalTypeTypeInfo(node).getQualifiedName();
        default:
            return TokenToTypeName.get(token);
        }
    }

    /**
     * Turns a location into an absolute URI string when Hive test mode is on.
     *
     * <p>Outside test mode the location is returned untouched. In test mode a
     * path that does not start with {@code /} is resolved against the
     * {@code test.tmp.dir} system property, and a missing scheme is replaced
     * by {@code pfile}.
     *
     * @param conf configuration that supplies the test-mode flag
     * @param location the location to convert
     * @return the absolute URI string in test mode, otherwise {@code location}
     * @throws SemanticException if the resulting URI is syntactically invalid
     */
    public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException {
        if (!conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE)) {
            //no-op for non-test mode for now
            return location;
        }

        URI parsed = new Path(location).toUri();
        String scheme = parsed.getScheme();
        String authority = parsed.getAuthority();
        String path = parsed.getPath();

        boolean isRelative = !path.startsWith("/");
        if (isRelative) {
            Path resolved = new Path(System.getProperty("test.tmp.dir"), path);
            path = resolved.toUri().getPath();
        }
        if (StringUtils.isEmpty(scheme)) {
            scheme = "pfile";
        }

        URI absolute;
        try {
            absolute = new URI(scheme, authority, path, null, null);
        } catch (URISyntaxException e) {
            throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e);
        }
        return absolute.toString();
    }
}