org.apache.hadoop.hive.ql.parse.MergeSemanticAnalyzer.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.hive.ql.parse.MergeSemanticAnalyzer.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.parse;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.antlr.runtime.TokenRewriteStream;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.Warehouse;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.metastore.api.MetaException;
import org.apache.hadoop.hive.ql.Context;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.QueryState;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.HiveUtils;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.session.SessionState;

/**
 * A subclass of the {@link org.apache.hadoop.hive.ql.parse.SemanticAnalyzer} that just handles
 * merge statements. It works by rewriting the updates and deletes into insert statements (since
 * they are actually inserts) and then doing some patch up to make them work as merges instead.
 */
public class MergeSemanticAnalyzer extends RewriteSemanticAnalyzer {
    /**
     * Creates an analyzer for a single MERGE statement.
     *
     * @param queryState state of the query being compiled; handed unchanged to the parent analyzer
     * @throws SemanticException if the parent analyzer cannot be initialized
     */
    MergeSemanticAnalyzer(QueryState queryState) throws SemanticException {
        // all setup is done by RewriteSemanticAnalyzer
        super(queryState);
    }

    /**
     * Entry point: only TOK_MERGE trees are accepted; anything else indicates a dispatch error.
     *
     * @param tree root of the statement AST
     * @throws SemanticException if the MERGE statement cannot be rewritten/analyzed
     */
    @Override
    public void analyze(ASTNode tree) throws SemanticException {
        boolean isMerge = tree.getToken().getType() == HiveParser.TOK_MERGE;
        if (isMerge) {
            analyzeMerge(tree);
            return;
        }
        throw new RuntimeException("Asked to parse token " + tree.getName() + " in " + "MergeSemanticAnalyzer");
    }

    /** Indentation prepended to lines of the generated FROM clause (two spaces). */
    private static final String INDENT = "  ";

    /** Quotes identifiers in the token stream; created anew at the start of each analyzeMerge() call. */
    private IdentifierQuoter quotedIdenfierHelper;

    /**
     * Makes it possible to turn an arbitrary ASTNode back into the SQL text that produced it.
     * HiveLexer.g strips the ` (back ticks) around quoted identifiers, and the parser only ever
     * produces Identifier tokens, never QuotedIdentifier (HIVE-6013). This helper therefore wraps
     * every Identifier token in back ticks by editing the TokenRewriteStream.
     * (') around String literals are retained without issues.
     */
    private static class IdentifierQuoter {
        private final TokenRewriteStream trs;
        /** Identifier nodes already quoted; keyed by identity so each node is edited at most once. */
        private final IdentityHashMap<ASTNode, ASTNode> visitedNodes = new IdentityHashMap<>();

        IdentifierQuoter(TokenRewriteStream trs) {
            if (trs == null) {
                throw new IllegalArgumentException("Must have a TokenRewriteStream");
            }
            this.trs = trs;
        }

        /** Pre-order walk of {@code n} that back-tick quotes every Identifier token not yet quoted. */
        private void visit(ASTNode n) {
            if (n.getType() == HiveParser.Identifier) {
                // Editing the stream is not idempotent. Ideally callers would quote each subtree only
                // once; remembering visited nodes makes repeated calls safe. A non-null previous value
                // means this node was already handled.
                if (visitedNodes.put(n, n) != null) {
                    return;
                }
                trs.insertBefore(n.getToken(), "`");
                trs.insertAfter(n.getToken(), "`");
            }
            if (n.getChildCount() > 0) {
                for (Node child : n.getChildren()) {
                    visit((ASTNode) child);
                }
            }
        }
    }

    /**
     * Regenerates the SQL text of {@code n} from the token stream, with identifiers back-tick quoted,
     * without needing to understand what kind of node it is.
     * The token range requested is one past the node's stop index, so the following token (e.g. a ',')
     * may be included; surrounding whitespace is trimmed.
     */
    private String getMatchedText(ASTNode n) {
        quotedIdenfierHelper.visit(n);
        TokenRewriteStream stream = ctx.getTokenRewriteStream();
        int firstToken = n.getTokenStartIndex();
        int endToken = n.getTokenStopIndex() + 1;
        String text = stream.toString(firstToken, endToken);
        return text.trim();
    }

    /**
     * Here we take a Merge statement AST and generate a semantically equivalent multi-insert
     * statement to execute.  Each Insert leg represents a single WHEN clause.  As much as possible,
     * the new SQL statement is made to look like the input SQL statement so that it's easier to map
     * Query Compiler errors from generated SQL to original one this way.
     * The generated SQL is a complete representation of the original input for the same reason.
     * In many places SemanticAnalyzer throws exceptions that contain (line, position) coordinates.
     * If generated SQL doesn't have everything and is patched up later, these coordinates point to
     * the wrong place.
     *
     * @param tree the TOK_MERGE root node
     * @throws SemanticException if the statement has no WHEN clause, has more than one WHEN MATCHED
     *         UPDATE or DELETE, has two WHEN MATCHED clauses but no predicate on the first one, or if
     *         validation of the target table or analysis of the rewritten statement fails
     */
    private void analyzeMerge(ASTNode tree) throws SemanticException {
        quotedIdenfierHelper = new IdentifierQuoter(ctx.getTokenRewriteStream());
        /*
         * See org.apache.hadoop.hive.ql.parse.TestMergeStatement for some examples of the merge AST
          For example, given:
          MERGE INTO acidTbl USING nonAcidPart2 source ON acidTbl.a = source.a2
          WHEN MATCHED THEN UPDATE SET b = source.b2
          WHEN NOT MATCHED THEN INSERT VALUES (source.a2, source.b2)

          We get AST like this:
          "(tok_merge " +
            "(tok_tabname acidtbl) (tok_tabref (tok_tabname nonacidpart2) source) " +
            "(= (. (tok_table_or_col acidtbl) a) (. (tok_table_or_col source) a2)) " +
            "(tok_matched " +
            "(tok_update " +
            "(tok_set_columns_clause (= (tok_table_or_col b) (. (tok_table_or_col source) b2))))) " +
            "(tok_not_matched " +
            "tok_insert " +
            "(tok_value_row (. (tok_table_or_col source) a2) (. (tok_table_or_col source) b2))))"

            And need to produce a multi-insert like this to execute:
            FROM acidTbl RIGHT OUTER JOIN nonAcidPart2 source ON acidTbl.a = source.a2
            INSERT INTO TABLE acidTbl SELECT source.a2, source.b2 WHERE acidTbl.a IS null
            INSERT INTO TABLE acidTbl SELECT acidTbl.ROW__ID, source.a2, source.b2
            WHERE source.a2=acidTbl.a SORT BY acidTbl.ROW__ID
        */
        /*todo: we need some sort of validation phase over original AST to make things user friendly; for example, if
         original command refers to a column that doesn't exist, this will be caught when processing the rewritten query but
         the errors will point at locations that the user can't map to anything
         - VALUES clause must have the same number of values as target table (including partition cols).  Part cols go last
         in Select clause of Insert as Select
         todo: do we care to preserve comments in original SQL?
         todo: check if identifiers are properly escaped/quoted in the generated SQL - it's currently inconsistent
          Look at UnparseTranslator.addIdentifierTranslation() - it does unescape + unparse...
         todo: consider "WHEN NOT MATCHED BY SOURCE THEN UPDATE SET TargetTable.Col1 = SourceTable.Col1 "; what happens when
         source is empty?  This should be a runtime error - maybe not the outer side of ROJ is empty => the join produces 0
         rows. If supporting WHEN NOT MATCHED BY SOURCE, then this should be a runtime error
        */
        ASTNode target = (ASTNode) tree.getChild(0);
        ASTNode source = (ASTNode) tree.getChild(1);
        String targetName = getSimpleTableName(target);
        String sourceName = getSimpleTableName(source);
        ASTNode onClause = (ASTNode) tree.getChild(2);
        String onClauseAsText = getMatchedText(onClause);

        int whenClauseBegins = 3;
        boolean hasHint = false;
        // Optional query hint. getChild() yields null when there is no 4th child; in that case
        // findWhenClauses() below reports the missing WHEN clause instead of failing here with an NPE.
        ASTNode qHint = (ASTNode) tree.getChild(3);
        if (qHint != null && qHint.getType() == HiveParser.QUERY_HINT) {
            hasHint = true;
            whenClauseBegins++;
        }
        Table targetTable = getTargetTable(target);
        validateTargetTable(targetTable);
        List<ASTNode> whenClauses = findWhenClauses(tree, whenClauseBegins);

        StringBuilder rewrittenQueryStr = new StringBuilder("FROM\n");

        rewrittenQueryStr.append(INDENT).append(getFullTableNameForSQL(target));
        if (isAliased(target)) {
            rewrittenQueryStr.append(" ").append(targetName);
        }
        rewrittenQueryStr.append('\n');
        rewrittenQueryStr.append(INDENT).append(chooseJoinType(whenClauses)).append("\n");
        if (source.getType() == HiveParser.TOK_SUBQUERY) {
            //this includes the mandatory alias
            rewrittenQueryStr.append(INDENT).append(getMatchedText(source));
        } else {
            rewrittenQueryStr.append(INDENT).append(getFullTableNameForSQL(source));
            if (isAliased(source)) {
                rewrittenQueryStr.append(" ").append(sourceName);
            }
        }
        rewrittenQueryStr.append('\n');
        rewrittenQueryStr.append(INDENT).append("ON ").append(onClauseAsText).append('\n');

        // Add the hint if any
        String hintStr = null;
        if (hasHint) {
            hintStr = " /*+ " + qHint.getText() + " */ ";
        }

        /**
         * We allow at most 2 WHEN MATCHED clause, in which case 1 must be Update the other Delete
         * If we have both update and delete, the 1st one (in SQL code) must have "AND <extra predicate>"
         * so that the 2nd can ensure not to process the same rows.
         * Update and Delete may be in any order.  (Insert is always last)
         */
        String extraPredicate = null;
        int numWhenMatchedUpdateClauses = 0, numWhenMatchedDeleteClauses = 0;
        int numInsertClauses = 0;
        boolean hintProcessed = false;
        for (ASTNode whenClause : whenClauses) {
            ASTNode whenOperation = getWhenClauseOperation(whenClause);
            switch (whenOperation.getType()) {
            case HiveParser.TOK_INSERT:
                numInsertClauses++;
                handleInsert(whenClause, rewrittenQueryStr, target, onClause, targetTable, targetName,
                        onClauseAsText, hintProcessed ? null : hintStr);
                hintProcessed = true;
                break;
            case HiveParser.TOK_UPDATE:
                numWhenMatchedUpdateClauses++;
                String updatePredicate = handleUpdate(whenClause, rewrittenQueryStr, target, onClauseAsText,
                        targetTable, extraPredicate, hintProcessed ? null : hintStr);
                hintProcessed = true;
                if (numWhenMatchedUpdateClauses + numWhenMatchedDeleteClauses == 1) {
                    extraPredicate = updatePredicate; //i.e. it's the 1st WHEN MATCHED
                }
                break;
            case HiveParser.TOK_DELETE:
                numWhenMatchedDeleteClauses++;
                String deletePredicate = handleDelete(whenClause, rewrittenQueryStr, target, onClauseAsText,
                        targetTable, extraPredicate, hintProcessed ? null : hintStr);
                hintProcessed = true;
                if (numWhenMatchedUpdateClauses + numWhenMatchedDeleteClauses == 1) {
                    extraPredicate = deletePredicate; //i.e. it's the 1st WHEN MATCHED
                }
                break;
            default:
                // report the operation type that was switched on, not the (always MATCHED/NOT_MATCHED) clause type
                throw new IllegalStateException(
                        "Unexpected WHEN clause type: " + whenOperation.getType() + addParseInfo(whenClause));
            }
            if (numWhenMatchedDeleteClauses > 1) {
                throw new SemanticException(ErrorMsg.MERGE_TOO_MANY_DELETE, ctx.getCmd());
            }
            if (numWhenMatchedUpdateClauses > 1) {
                throw new SemanticException(ErrorMsg.MERGE_TOO_MANY_UPDATE, ctx.getCmd());
            }
            assert numInsertClauses < 2 : "too many Insert clauses";
        }
        if (numWhenMatchedDeleteClauses + numWhenMatchedUpdateClauses == 2 && extraPredicate == null) {
            throw new SemanticException(ErrorMsg.MERGE_PREDIACTE_REQUIRED, ctx.getCmd());
        }

        boolean validating = handleCardinalityViolation(rewrittenQueryStr, target, onClauseAsText, targetTable,
                numWhenMatchedDeleteClauses == 0 && numWhenMatchedUpdateClauses == 0);
        ReparseResult rr = parseRewrittenQuery(rewrittenQueryStr, ctx.getCmd());
        Context rewrittenCtx = rr.rewrittenCtx;
        ASTNode rewrittenTree = rr.rewrittenTree;
        rewrittenCtx.setOperation(Context.Operation.MERGE);

        //set dest name mapping on new context; 1st child is TOK_FROM
        for (int insClauseIdx = 1, whenClauseIdx = 0; insClauseIdx < rewrittenTree.getChildCount()
                - (validating ? 1 : 0/*skip cardinality violation clause*/); insClauseIdx++, whenClauseIdx++) {
            //we've added Insert clauses in order of WHEN items in whenClauses
            switch (getWhenClauseOperation(whenClauses.get(whenClauseIdx)).getType()) {
            case HiveParser.TOK_INSERT:
                rewrittenCtx.addDestNamePrefix(insClauseIdx, Context.DestClausePrefix.INSERT);
                break;
            case HiveParser.TOK_UPDATE:
                rewrittenCtx.addDestNamePrefix(insClauseIdx, Context.DestClausePrefix.UPDATE);
                break;
            case HiveParser.TOK_DELETE:
                rewrittenCtx.addDestNamePrefix(insClauseIdx, Context.DestClausePrefix.DELETE);
                break;
            default:
                assert false;
            }
        }
        if (validating) {
            //here means the last branch of the multi-insert is Cardinality Validation
            rewrittenCtx.addDestNamePrefix(rewrittenTree.getChildCount() - 1, Context.DestClausePrefix.INSERT);
        }

        try {
            useSuper = true;
            super.analyze(rewrittenTree, rewrittenCtx);
        } finally {
            useSuper = false;
        }
        updateOutputs(targetTable);
    }

    /**
     * Picks the join between target and source: an outer join is only needed when some WHEN clause
     * is a NOT MATCHED ... INSERT; otherwise an inner join suffices.
     */
    private String chooseJoinType(List<ASTNode> whenClauses) {
        boolean needsOuterJoin = whenClauses.stream()
                .anyMatch(clause -> getWhenClauseOperation(clause).getType() == HiveParser.TOK_INSERT);
        return needsOuterJoin ? "RIGHT OUTER JOIN" : "INNER JOIN";
    }

    /**
     * Per SQL Spec ISO/IEC 9075-2:2011(E) Section 14.2 under "General Rules" Item 6/Subitem a/Subitem 2/Subitem B,
     * an error should be raised if > 1 row of "source" matches the same row in "target".
     * The check is added as an extra leg of the multi-insert, so it runs in parallel with the other
     * legs. It never actually writes data to merge_tmp_table, because the cardinality_violation() UDF
     * raises an error whenever it is called, which kills the query.
     *
     * @param onlyHaveWhenNotMatchedClause true when the MERGE has no UPDATE/DELETE legs
     * @return true if another Insert clause was added
     */
    private boolean handleCardinalityViolation(StringBuilder rewrittenQueryStr, ASTNode target,
            String onClauseAsString, Table targetTable, boolean onlyHaveWhenNotMatchedClause)
            throws SemanticException {
        boolean checkEnabled = conf.getBoolVar(HiveConf.ConfVars.MERGE_CARDINALITY_VIOLATION_CHECK);
        if (!checkEnabled) {
            LOG.info("Merge statement cardinality violation check is disabled: "
                    + HiveConf.ConfVars.MERGE_CARDINALITY_VIOLATION_CHECK.varname);
            return false;
        }
        if (onlyHaveWhenNotMatchedClause) {
            // without update or delete no target row can be touched twice, so there is nothing to check
            return false;
        }
        // Temp tables are session scoped and acid serializes SQL statements within a session, so a
        // fixed name can be reused by every invocation.
        String tmpTableName = "merge_tmp_table";

        // For a table T partitioned by p this generates
        //   select cardinality_violation(ROW__ID, p) WHERE ... GROUP BY ROW__ID, p HAVING count(*) > 1
        // The GROUP BY values are passed to cardinality_violation so the error message can name the violating row.
        StringBuilder sql = rewrittenQueryStr;
        sql.append("\nINSERT INTO ").append(tmpTableName).append("\n  SELECT cardinality_violation(");
        sql.append(getSimpleTableName(target)).append(".ROW__ID");
        addPartitionColsToSelect(targetTable.getPartCols(), sql, target);
        sql.append(")\n WHERE ").append(onClauseAsString).append(" GROUP BY ");
        sql.append(getSimpleTableName(target)).append(".ROW__ID");
        addPartitionColsToSelect(targetTable.getPartCols(), sql, target);
        sql.append(" HAVING count(*) > 1");

        createMergeTmpTableIfAbsent(tmpTableName);
        return true;
    }

    /**
     * Creates the temporary TextFile table (single int column "val") that the cardinality check leg
     * inserts into, unless a table with that name already exists.
     */
    private void createMergeTmpTableIfAbsent(String tableName) throws SemanticException {
        try {
            if (db.getTable(tableName, false) != null) {
                return;
            }
            StorageFormat storage = new StorageFormat(conf);
            storage.processStorageFormat("TextFile");
            Table tmpTable = db.newTable(tableName);
            tmpTable.setSerializationLib(storage.getSerde());
            List<FieldSchema> columns = new ArrayList<FieldSchema>();
            columns.add(new FieldSchema("val", "int", null));
            tmpTable.setFields(columns);
            Path location = new Path(SessionState.get().getTempTableSpace(), tableName);
            tmpTable.setDataLocation(Warehouse.getDnsPath(location, conf));
            tmpTable.getTTable().setTemporary(true);
            tmpTable.setStoredAsSubDirectories(false);
            tmpTable.setInputFormatClass(storage.getInputFormat());
            tmpTable.setOutputFormatClass(storage.getOutputFormat());
            db.createTable(tmpTable, true);
        } catch (HiveException | MetaException e) {
            throw new SemanticException(e.getMessage(), e);
        }
    }

    /**
     * Generates the Insert leg of the multi-insert SQL for a WHEN MATCHED THEN UPDATE clause.
     * Every matched target row is re-inserted keyed by its ROW__ID. Each SET column takes the RHS
     * expression exactly as written in the input; every other column takes its current value.
     *
     * @param whenMatchedUpdateClause TOK_MATCHED node whose operation is TOK_UPDATE
     * @param rewrittenQueryStr the multi-insert statement being built; appended to
     * @param target target table AST node
     * @param onClauseAsString - because there is no clone() and we need to use in multiple places
     * @param targetTable metadata of the target table
     * @param deleteExtraPredicate - see notes at caller; when non-null, rows matching it are excluded
     * @param hintStr query hint emitted right after SELECT, or null
     * @return the text of this clause's own "AND <predicate>", or null if it has none
     */
    private String handleUpdate(ASTNode whenMatchedUpdateClause, StringBuilder rewrittenQueryStr, ASTNode target,
            String onClauseAsString, Table targetTable, String deleteExtraPredicate, String hintStr)
            throws SemanticException {
        assert whenMatchedUpdateClause.getType() == HiveParser.TOK_MATCHED;
        assert getWhenClauseOperation(whenMatchedUpdateClause).getType() == HiveParser.TOK_UPDATE;
        String targetName = getSimpleTableName(target);
        rewrittenQueryStr.append("INSERT INTO ").append(getFullTableNameForSQL(target));
        addPartitionColsToInsert(targetTable.getPartCols(), rewrittenQueryStr);
        rewrittenQueryStr.append("    -- update clause\n SELECT ");
        if (hintStr != null) {
            rewrittenQueryStr.append(hintStr);
        }
        rewrittenQueryStr.append(targetName).append(".ROW__ID");

        ASTNode setClause = (ASTNode) getWhenClauseOperation(whenMatchedUpdateClause).getChild(0);
        //columns being updated -> update expressions; "setRCols" (last param) is null because we use actual expressions
        //before reparsing, i.e. they are known to SemanticAnalyzer logic
        Map<String, ASTNode> setColsExprs = collectSetColumnsAndExpressions(setClause, null, targetTable);
        //if target table has cols c1,c2,c3 and p1 partition col and we had "SET c2 = 5, c1 = current_date()" we want to end
        //up with
        //insert into target (p1) select current_date(), 5, c3, p1 where ....
        //since we take the RHS of set exactly as it was in Input, we don't need to deal with quoting/escaping column/table
        //names
        for (FieldSchema fs : targetTable.getCols()) {
            rewrittenQueryStr.append(", ");
            String name = fs.getName();
            if (setColsExprs.containsKey(name)) {
                rewrittenQueryStr.append(stripTrailingComma(getMatchedText(setColsExprs.get(name))));
            } else {
                rewrittenQueryStr.append(targetName).append(".")
                        .append(HiveUtils.unparseIdentifier(name, this.conf));
            }
        }
        addPartitionColsToSelect(targetTable.getPartCols(), rewrittenQueryStr, target);
        rewrittenQueryStr.append("\n   WHERE ").append(onClauseAsString);
        String extraPredicate = getWhenClausePredicate(whenMatchedUpdateClause);
        if (extraPredicate != null) {
            //we have WHEN MATCHED AND <boolean expr> THEN UPDATE
            rewrittenQueryStr.append(" AND ").append(extraPredicate);
        }
        if (deleteExtraPredicate != null) {
            rewrittenQueryStr.append(" AND NOT(").append(deleteExtraPredicate).append(")");
        }
        rewrittenQueryStr.append("\n SORT BY ");
        rewrittenQueryStr.append(targetName).append(".ROW__ID \n");

        setUpAccessControlInfoForUpdate(targetTable, setColsExprs);
        //we don't deal with columns on RHS of SET expression since the whole expr is part of the
        //rewritten SQL statement and is thus handled by SemanticAnalyzer.  Nor do we have to
        //figure which cols on RHS are from source and which from target

        return extraPredicate;
    }

    /**
     * getMatchedText() includes the token after the expression, so for "set a=5, b=8" the RHS of
     * {@code a} comes back as "5,". Its result is already trimmed, so a trailing ',' is the only
     * separator that can remain; an empty string is returned unchanged.
     */
    private static String stripTrailingComma(String rhsExp) {
        return rhsExp.endsWith(",") ? rhsExp.substring(0, rhsExp.length() - 1) : rhsExp;
    }

    /**
     * Generates the Insert leg of the multi-insert SQL for a WHEN MATCHED THEN DELETE clause: it selects
     * the ROW__ID (followed by the partition columns) of every matched target row to remove.
     *
     * @param onClauseAsString - because there is no clone() and we need to use in multiple places
     * @param updateExtraPredicate - see notes at caller; when non-null, rows matching it are excluded
     * @param hintStr query hint emitted right after SELECT, or null
     * @return the text of this clause's own "AND <predicate>", or null if it has none
     */
    private String handleDelete(ASTNode whenMatchedDeleteClause, StringBuilder rewrittenQueryStr, ASTNode target,
            String onClauseAsString, Table targetTable, String updateExtraPredicate, String hintStr)
            throws SemanticException {
        assert whenMatchedDeleteClause.getType() == HiveParser.TOK_MATCHED;
        assert getWhenClauseOperation(whenMatchedDeleteClause).getType() == HiveParser.TOK_DELETE;
        List<FieldSchema> partitionColumns = targetTable.getPartCols();
        String alias = getSimpleTableName(target);
        StringBuilder sql = rewrittenQueryStr;

        sql.append("INSERT INTO ").append(getFullTableNameForSQL(target));
        addPartitionColsToInsert(partitionColumns, sql);
        sql.append("    -- delete clause\n SELECT ");
        if (hintStr != null) {
            sql.append(hintStr);
        }
        sql.append(alias).append(".ROW__ID ");
        addPartitionColsToSelect(partitionColumns, sql, target);
        sql.append("\n   WHERE ").append(onClauseAsString);

        // WHEN MATCHED AND <boolean expr> THEN DELETE narrows the rows further
        String ownPredicate = getWhenClausePredicate(whenMatchedDeleteClause);
        if (ownPredicate != null) {
            sql.append(" AND ").append(ownPredicate);
        }
        // rows claimed by an earlier WHEN MATCHED ... UPDATE must not be deleted as well
        if (updateExtraPredicate != null) {
            sql.append(" AND NOT(").append(updateExtraPredicate).append(")");
        }
        sql.append("\n SORT BY ").append(alias).append(".ROW__ID \n");
        return ownPredicate;
    }

    /** Formats the source position of {@code n} as an " at ..." suffix for error messages. */
    private static String addParseInfo(ASTNode n) {
        return new StringBuilder(" at ").append(ErrorMsg.renderPosition(n)).toString();
    }

    /**
     * Tells whether a table reference in the MERGE statement carries an alias.
     * A plain TOK_TABNAME never does. A derived table (TOK_SUBQUERY) always must.
     * For TOK_TABREF the decision is delegated to findTabRefIdxs().
     *
     * @throws IllegalArgumentException (via raiseWrongType) for any other node type
     */
    private boolean isAliased(ASTNode n) {
        switch (n.getType()) {
        case HiveParser.TOK_TABREF:
            return findTabRefIdxs(n)[0] != 0;
        case HiveParser.TOK_TABNAME:
            return false;
        case HiveParser.TOK_SUBQUERY:
            assert n.getChildCount() > 1 : "Expected Derived Table to be aliased";
            return true;
        default:
            // list every type accepted above, including TOK_SUBQUERY
            throw raiseWrongType("TOK_TABREF|TOK_TABNAME|TOK_SUBQUERY", n);
        }
    }

    /**
     * Collects the WHEN clauses of a Merge statement AST, i.e. the children of {@code tree} from index
     * {@code start} onwards.
     *
     * @throws SemanticException if there is no WHEN clause at all
     */
    private List<ASTNode> findWhenClauses(ASTNode tree, int start) throws SemanticException {
        assert tree.getType() == HiveParser.TOK_MERGE;
        int childCount = tree.getChildCount();
        List<ASTNode> clauses = new ArrayList<>();
        for (int i = start; i < childCount; i++) {
            ASTNode candidate = (ASTNode) tree.getChild(i);
            assert candidate.getType() == HiveParser.TOK_MATCHED
                    || candidate.getType() == HiveParser.TOK_NOT_MATCHED : "Unexpected node type found: "
                            + candidate.getType() + addParseInfo(candidate);
            clauses.add(candidate);
        }
        if (clauses.isEmpty()) {
            //Futureproofing: the parser will actually not allow this
            throw new SemanticException("Must have at least 1 WHEN clause in MERGE statement");
        }
        return clauses;
    }

    /**
     * Returns the operation node (TOK_INSERT/TOK_UPDATE/TOK_DELETE) of a WHEN clause, which is its
     * first child.
     *
     * @throws IllegalArgumentException (via raiseWrongType) if {@code whenClause} is not a WHEN clause
     */
    private ASTNode getWhenClauseOperation(ASTNode whenClause) {
        int clauseType = whenClause.getType();
        boolean isWhenClause = clauseType == HiveParser.TOK_MATCHED || clauseType == HiveParser.TOK_NOT_MATCHED;
        if (!isWhenClause) {
            throw raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
        }
        return (ASTNode) whenClause.getChild(0);
    }

    /**
     * Returns the <boolean predicate> as in WHEN MATCHED AND <boolean predicate> THEN...
     * The predicate, when present, is the second child of the WHEN clause.
     *
     * @return the predicate's SQL text, or null when the clause has none
     * @throws IllegalArgumentException (via raiseWrongType) if {@code whenClause} is not a WHEN clause
     */
    private String getWhenClausePredicate(ASTNode whenClause) {
        int clauseType = whenClause.getType();
        if (clauseType != HiveParser.TOK_MATCHED && clauseType != HiveParser.TOK_NOT_MATCHED) {
            throw raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
        }
        boolean hasPredicate = whenClause.getChildCount() == 2;
        return hasPredicate ? getMatchedText((ASTNode) whenClause.getChild(1)) : null;
    }

    /**
     * Generates the Insert leg of the multi-insert SQL to represent WHEN NOT MATCHED THEN INSERT clause.
     *
     * @param whenNotMatchedClause TOK_NOT_MATCHED node whose operation is TOK_INSERT
     * @param rewrittenQueryStr the multi-insert statement being built; appended to
     * @param target target table AST node
     * @param onClause ON clause AST, analyzed to build the "not matched" predicate
     * @param targetTable metadata of the target table
     * @param targetTableNameInSourceQuery - simple name/alias
     * @param onClauseAsString text of the ON clause
     * @param hintStr query hint emitted right after SELECT, or null
     * @throws SemanticException if the INSERT has no VALUES clause, or an explicit column list whose
     *         length differs from the number of values
     */
    private void handleInsert(ASTNode whenNotMatchedClause, StringBuilder rewrittenQueryStr, ASTNode target,
            ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery, String onClauseAsString,
            String hintStr) throws SemanticException {
        ASTNode whenClauseOperation = getWhenClauseOperation(whenNotMatchedClause);
        assert whenNotMatchedClause.getType() == HiveParser.TOK_NOT_MATCHED;
        assert whenClauseOperation.getType() == HiveParser.TOK_INSERT;

        // identify the node that contains the values to insert and the optional column list node
        List<Node> children = whenClauseOperation.getChildren();
        ASTNode valuesNode = (ASTNode) children.stream()
                .filter(n -> ((ASTNode) n).getType() == HiveParser.TOK_FUNCTION).findFirst()
                .orElseThrow(() -> new SemanticException(
                        "Missing VALUES in WHEN NOT MATCHED THEN INSERT" + addParseInfo(whenClauseOperation)));
        ASTNode columnListNode = (ASTNode) children.stream()
                .filter(n -> ((ASTNode) n).getType() == HiveParser.TOK_TABCOLNAME).findFirst().orElse(null);

        // if column list is specified, then it has to have the same number of elements as the values
        // valuesNode has a child for struct, the rest are the columns
        if (columnListNode != null && columnListNode.getChildCount() != (valuesNode.getChildCount() - 1)) {
            throw new SemanticException(
                    String.format("Column schema must have the same length as values (%d vs %d)",
                            columnListNode.getChildCount(), valuesNode.getChildCount() - 1));
        }

        rewrittenQueryStr.append("INSERT INTO ").append(getFullTableNameForSQL(target));
        if (columnListNode != null) {
            rewrittenQueryStr.append(' ').append(getMatchedText(columnListNode));
        }
        addPartitionColsToInsert(targetTable.getPartCols(), rewrittenQueryStr);

        rewrittenQueryStr.append("    -- insert clause\n  SELECT ");
        if (hintStr != null) {
            rewrittenQueryStr.append(hintStr);
        }

        OnClauseAnalyzer oca = new OnClauseAnalyzer(onClause, targetTable, targetTableNameInSourceQuery, conf,
                onClauseAsString);
        oca.analyze();

        String valuesClause = getMatchedText(valuesNode);
        valuesClause = valuesClause.substring(1, valuesClause.length() - 1); //strip '(' and ')'
        valuesClause = replaceDefaultKeywordForMerge(valuesClause, targetTable, columnListNode);
        rewrittenQueryStr.append(valuesClause).append("\n   WHERE ").append(oca.getPredicate());

        String extraPredicate = getWhenClausePredicate(whenNotMatchedClause);
        if (extraPredicate != null) {
            //we have WHEN NOT MATCHED AND <boolean expr> THEN INSERT
            rewrittenQueryStr.append(" AND ").append(extraPredicate).append('\n');
        }
    }

    /**
     * Replaces every {@code `default`} entry of a VALUES list with the DEFAULT constraint of the
     * corresponding target column, or with {@code null} when that column has no such constraint.
     *
     * @param valueClause comma separated value expressions, without the enclosing parentheses
     * @param table target table
     * @param columnListNode explicit INSERT column list, or null to pair values with all table columns in order
     * @return {@code valueClause} unchanged if it has no DEFAULT keyword, otherwise the rewritten list
     */
    private String replaceDefaultKeywordForMerge(String valueClause, Table table, ASTNode columnListNode)
            throws SemanticException {
        if (!valueClause.toLowerCase().contains("`default`")) {
            return valueClause;
        }

        Map<String, String> colNameToDefaultConstraint = getColNameToDefaultValueMap(table);
        // Split only on top-level commas: a plain split(",") also cuts through function arguments such as
        // concat(a, b) and string literals such as 'x,y', which misaligns values with column names.
        String[] values = splitTopLevelCommas(valueClause.trim());
        String[] replacedValues = new String[values.length];

        // the list of the column names may be set in the query
        String[] columnNames = columnListNode == null
                ? table.getAllCols().stream().map(f -> f.getName()).toArray(size -> new String[size])
                : columnListNode.getChildren().stream().map(n -> ((ASTNode) n).toString())
                        .toArray(size -> new String[size]);

        for (int i = 0; i < values.length; i++) {
            if (values[i].trim().toLowerCase().equals("`default`")) {
                // NOTE(review): with more values than columns there is no column to look up; emit null and
                // leave the count mismatch to later analysis instead of an ArrayIndexOutOfBoundsException.
                replacedValues[i] = i < columnNames.length
                        ? MapUtils.getString(colNameToDefaultConstraint, columnNames[i], "null")
                        : "null";
            } else {
                replacedValues[i] = values[i];
            }
        }
        return StringUtils.join(replacedValues, ',');
    }

    /**
     * Splits {@code s} at commas that are outside parentheses/brackets and outside '...', "..." or `...`
     * quoted sections (a backslash escapes the next character inside '...' and "..."). The pieces keep
     * their surrounding whitespace so that joining them with ',' restores the input.
     */
    private static String[] splitTopLevelCommas(String s) {
        List<String> parts = new ArrayList<>();
        int depth = 0;
        char quote = 0;
        int start = 0;
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (quote != 0) {
                if (c == '\\' && quote != '`') {
                    i++; // skip the escaped character
                } else if (c == quote) {
                    quote = 0;
                }
                continue;
            }
            switch (c) {
            case '\'':
            case '"':
            case '`':
                quote = c;
                break;
            case '(':
            case '[':
                depth++;
                break;
            case ')':
            case ']':
                if (depth > 0) {
                    depth--;
                }
                break;
            case ',':
                if (depth == 0) {
                    parts.add(s.substring(start, i));
                    start = i + 1;
                }
                break;
            default:
                break;
            }
        }
        parts.add(s.substring(start));
        return parts.toArray(new String[0]);
    }

    /**
     * Suppose the input Merge statement has ON target.a = source.b and c = d.  Assume that 'c' is from
     * the target table and 'd' is from the source expression.  In order to properly
     * generate the Insert for WHEN NOT MATCHED THEN INSERT, we need to make sure that the Where
     * clause of this Insert contains "target.a is null and target.c is null".  This ensures that this
     * Insert leg does not receive any rows that are processed by the Inserts corresponding to
     * WHEN MATCHED THEN ... clauses.  (Implicit in this is a mini resolver that figures out if an
     * unqualified column is part of the target table.  We can get away with this simple logic because
     * we know that target is always a table, as opposed to some derived table.)
     * The job of this class is to generate this predicate.
     *
     * Note that this predicate cannot simply be NOT(on-clause-expr).  If on-clause-expr evaluates
     * to Unknown, it will be treated as False in the WHEN MATCHED Inserts but NOT(Unknown) = Unknown,
     * and so it will also be False for the WHEN NOT MATCHED Insert, i.e. the row would be picked up
     * by none of the legs.
     *
     * Usage: construct, call {@link #analyze()}, then {@link #getPredicate()}.  Columns are only
     * collected by {@code analyze()}, so calling {@code getPredicate()} first always throws.
     */
    private static final class OnClauseAnalyzer {
        /** Root of the ON clause expression tree that is walked by {@link #visit(ASTNode)}. */
        private final ASTNode onClause;
        /**
         * Lower-cased table name/alias -> column names referenced through it, kept in their original
         * spelling.  A list is created only when a column is added, so no stored list is ever empty.
         */
        private final Map<String, List<String>> table2column = new HashMap<>();
        /** Column references that appear in the ON clause without a table qualifier. */
        private final List<String> unresolvedColumns = new ArrayList<>();
        /** Data columns followed by partition columns of the target table. */
        private final List<FieldSchema> allTargetTableColumns = new ArrayList<>();
        /** Lower-cased table names/aliases referenced in the ON clause. */
        private final Set<String> tableNamesFound = new HashSet<>();
        /** Unescaped alias or simple name of the target table as used in the source query. */
        private final String targetTableNameInSourceQuery;
        private final HiveConf conf;
        /** Text of the ON clause, used only in error messages. */
        private final String onClauseAsString;

        /**
         * @param onClause root AST node of the ON clause
         * @param targetTable target table of the MERGE; its data and partition columns are used to
         *                    resolve unqualified column references
         * @param targetTableNameInSourceQuery alias or simple name; it is unescaped before use
         * @param conf configuration passed to {@code HiveUtils.unparseIdentifier} when quoting names
         * @param onClauseAsString textual form of the ON clause, used only in error messages
         */
        OnClauseAnalyzer(ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery, HiveConf conf,
                String onClauseAsString) {
            this.onClause = onClause;
            allTargetTableColumns.addAll(targetTable.getCols());
            allTargetTableColumns.addAll(targetTable.getPartCols());
            this.targetTableNameInSourceQuery = unescapeIdentifier(targetTableNameInSourceQuery);
            this.conf = conf;
            this.onClauseAsString = onClauseAsString;
        }

        /**
         * Recursively finds all column references and groups them by table ref (if there is one).
         * Qualified references ({@code tbl.col}) go to {@link #table2column}; unqualified ones are
         * recorded in {@link #unresolvedColumns} for later resolution.
         *
         * @throws IllegalArgumentException if a {@code db.table.col} style reference is found
         */
        private void visit(ASTNode n) {
            if (n.getType() == HiveParser.TOK_TABLE_OR_COL) {
                ASTNode parent = (ASTNode) n.getParent();
                if (parent != null && parent.getType() == HiveParser.DOT) {
                    //the ref must be a table, so look for column name as right child of DOT
                    if (parent.getParent() != null && parent.getParent().getType() == HiveParser.DOT) {
                        //I don't think this can happen... but just in case
                        throw new IllegalArgumentException(
                                "Found unexpected db.table.col reference in " + onClauseAsString);
                    }
                    addColumn2Table(n.getChild(0).getText(), parent.getChild(1).getText());
                } else {
                    //must be just a column name
                    unresolvedColumns.add(n.getChild(0).getText());
                }
            }
            // getChildren() may be null for a leaf, hence the explicit guard
            if (n.getChildCount() == 0) {
                return;
            }
            for (Node child : n.getChildren()) {
                visit((ASTNode) child);
            }
        }

        /**
         * Collects all column references of the ON clause and resolves unqualified ones against
         * the target table.
         *
         * @throws IllegalArgumentException if more than two distinct table refs are found, either
         *                                  before or after resolving unqualified columns
         */
        private void analyze() {
            visit(onClause);
            if (tableNamesFound.size() > 2) {
                throw new IllegalArgumentException(
                        "Found > 2 table refs in ON clause.  Found " + tableNamesFound + " in " + onClauseAsString);
            }
            handleUnresolvedColumns();
            if (tableNamesFound.size() > 2) {
                throw new IllegalArgumentException("Found > 2 table refs in ON clause (incl unresolved).  "
                        + "Found " + tableNamesFound + " in " + onClauseAsString);
            }
        }

        /**
         * Attributes every unqualified column that names a target table column to the target table.
         * Columns that do not match are left unattributed and do not contribute to the predicate.
         */
        private void handleUnresolvedColumns() {
            for (String c : unresolvedColumns) {
                if (isTargetTableColumn(c)) {
                    //c belongs to target table; strictly speaking there may be an ambiguous ref but
                    //this will be caught later when multi-insert is parsed
                    //(addColumn2Table normalizes the table name, so no lower-casing is needed here)
                    addColumn2Table(targetTableNameInSourceQuery, c);
                }
            }
        }

        /** Returns true if {@code columnName} case-insensitively matches a target table column. */
        private boolean isTargetTableColumn(String columnName) {
            for (FieldSchema fs : allTargetTableColumns) {
                if (columnName.equalsIgnoreCase(fs.getName())) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Records that {@code columnName} is referenced through {@code tableName}.  The table name is
         * lower-cased for mapping; the column name is kept as written.
         */
        private void addColumn2Table(String tableName, String columnName) {
            String normalizedTableName = tableName.toLowerCase(); //normalize name for mapping
            tableNamesFound.add(normalizedTableName);
            //we want to preserve 'columnName' as it was in original input query so that rewrite
            //looks as much as possible like original query
            table2column.computeIfAbsent(normalizedTableName, k -> new ArrayList<>()).add(columnName);
        }

        /**
         * Generates the predicate for the Where clause of the WHEN NOT MATCHED Insert: one
         * {@code <target>.<col> IS NULL} term per target column reference, joined with {@code AND}.
         *
         * @return the predicate text, with identifiers quoted via {@code HiveUtils.unparseIdentifier}
         * @throws IllegalArgumentException if the ON clause references no target table column
         */
        private String getPredicate() {
            //normalize table name for mapping
            List<String> targetCols = table2column.get(targetTableNameInSourceQuery.toLowerCase());
            if (targetCols == null) {
                /*e.g. ON source.t=1
                * this is not strictly speaking invalid but it does ensure that all columns from target
                * table are all NULL for every row.  This would make any WHEN MATCHED clause invalid since
                * we don't have a ROW__ID.  The WHEN NOT MATCHED could be meaningful but it's just data from
                * source satisfying source.t=1...  not worth the effort to support this*/
                throw new IllegalArgumentException(ErrorMsg.INVALID_TABLE_IN_ON_CLAUSE_OF_MERGE
                        .format(targetTableNameInSourceQuery, onClauseAsString));
            }
            //but preserve table name in SQL; it is the same for every term, so quote it once
            String quotedTargetTable = HiveUtils.unparseIdentifier(targetTableNameInSourceQuery, conf);
            StringBuilder sb = new StringBuilder();
            for (String col : targetCols) {
                if (sb.length() > 0) {
                    sb.append(" AND ");
                }
                sb.append(quotedTargetTable).append(".")
                        .append(HiveUtils.unparseIdentifier(col, conf)).append(" IS NULL");
            }
            return sb.toString();
        }
    }
}