org.apache.hadoop.hive.ql.exec.tez.DynamicPartitionPruner.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.hive.ql.exec.tez.DynamicPartitionPruner.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.exec.tez;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;

import org.apache.commons.lang3.mutable.MutableInt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputInitializerEvent;

/**
 * DynamicPartitionPruner takes a list of assigned partitions at runtime (split
 * generation) and prunes them using events generated during execution of the
 * dag.
 *
 * <p>Events arrive through {@link #addEvent(InputInitializerEvent)} and source-vertex success
 * notifications through {@link #processVertex(String)}. {@link #prune()} blocks on an internal queue until
 * every source vertex has succeeded and has delivered all of its expected events (one per task per pruned
 * column), then removes the non-matching partition paths from the {@link MapWork}. Bookkeeping shared
 * between these entry points is guarded by the {@code sourcesWaitingForEvents} monitor.
 */
public class DynamicPartitionPruner {

    private static final Logger LOG = LoggerFactory.getLogger(DynamicPartitionPruner.class);

    private final InputInitializerContext context;
    private final MapWork work;
    private final JobConf jobConf;

    /* Source vertex name -> one SourceInfo per column that the vertex sends pruning events for. */
    private final Map<String, List<SourceInfo>> sourceInfoMap = new HashMap<String, List<SourceInfo>>();

    /* Scratch buffer reused by processPayload() for every serialized row it decodes. */
    private final BytesWritable writable = new BytesWritable();

    /* Keeps track of all events that need to be processed - irrespective of the source */
    private final BlockingQueue<Object> queue = new LinkedBlockingQueue<Object>();

    /* Keeps track of vertices from which events are expected. Also serves as the monitor guarding the
     * per-source counters and totalEventCount. */
    private final Set<String> sourcesWaitingForEvents = new HashSet<String>();

    // Stores -(#columns) per source until the source vertex succeeds; then set to #tasks X #columns.
    private final Map<String, MutableInt> numExpectedEventsPerSource = new HashMap<>();
    private final Map<String, MutableInt> numEventsSeenPerSource = new HashMap<>();

    /* Number of (source, column) pairs registered by initialize(). */
    private int sourceInfoCount = 0;

    /* Sentinel put on the queue once no source is waiting for events any more. */
    private final Object endOfEvents = new Object();

    /* Number of events accepted by addEvent() across all sources. */
    private int totalEventCount = 0;

    /**
     * Creates a pruner for the dynamic-pruning event sources declared on {@code work}.
     *
     * @param context Tez input-initializer context; used to query task counts and to register for
     *                vertex state updates
     * @param work    map work describing the event sources; its partition paths are pruned in place
     * @param jobConf configuration used to initialize each source's deserializer
     * @throws SerDeException if a source's deserializer cannot be set up
     */
    public DynamicPartitionPruner(InputInitializerContext context, MapWork work, JobConf jobConf)
            throws SerDeException {
        this.context = context;
        this.work = work;
        this.jobConf = jobConf;
        synchronized (this) {
            initialize();
        }
    }

    /**
     * Waits for all pruning events and then prunes the partition paths of the {@link MapWork}.
     * It returns immediately if no source is waiting for events.
     *
     * @throws InterruptedException if interrupted while waiting for events
     * @throws HiveException        if pruning fails or the number of received events is not the expected one
     * @throws SerDeException       if an event payload cannot be deserialized
     * @throws IOException          if an event payload cannot be read
     */
    public void prune() throws SerDeException, IOException, InterruptedException, HiveException {

        synchronized (sourcesWaitingForEvents) {

            if (sourcesWaitingForEvents.isEmpty()) {
                return;
            }

            Set<VertexState> states = Collections.singleton(VertexState.SUCCEEDED);
            for (String source : sourcesWaitingForEvents) {
                // we need to get state transition updates for the vertices that will send
                // events to us. once we have received all events and a vertex has succeeded,
                // we can move to do the pruning.
                context.registerForVertexStateUpdates(source, states);
            }
        }

        LOG.info("Waiting for events ({} sources) ...", sourceInfoCount);
        // synchronous event processing loop. Won't return until all events have
        // been processed.
        processEvents();
        prunePartitions();
        LOG.info("Ok to proceed.");
    }

    /**
     * Returns the queue that events (and, finally, the end-of-events sentinel) are placed on.
     *
     * @return the internal event queue
     */
    public BlockingQueue<Object> getQueue() {
        return queue;
    }

    /** Resets the per-source SourceInfo registry. */
    private void clear() {
        sourceInfoMap.clear();
        sourceInfoCount = 0;
    }

    /**
     * Builds one SourceInfo per (source vertex, column) pair from the event-source maps of the MapWork and
     * initializes the expected/seen event counters.
     * SourceInfos that restrict the same column share their value set and skip flag.
     *
     * @throws SerDeException if a source's deserializer cannot be set up
     */
    private void initialize() throws SerDeException {
        clear();
        // Column name -> most recently created SourceInfo for that column.
        Map<String, SourceInfo> columnMap = new HashMap<String, SourceInfo>();
        // sources represent vertex names
        Set<String> sources = work.getEventSourceTableDescMap().keySet();

        sourcesWaitingForEvents.addAll(sources);

        for (String s : sources) {
            // Starts at 0 and is decremented once per column for which this source generates events.
            // processVertex() later turns it into the number of expected events: #columns X #tasks.
            MutableInt expectedEvents = new MutableInt(0);
            numExpectedEventsPerSource.put(s, expectedEvents);
            numEventsSeenPerSource.put(s, new MutableInt(0));
            // Virtual relation generated by the reduce sync
            List<TableDesc> tables = work.getEventSourceTableDescMap().get(s);
            // Real column name - on which the operation is being performed
            List<String> columnNames = work.getEventSourceColumnNameMap().get(s);
            // Column type
            List<String> columnTypes = work.getEventSourceColumnTypeMap().get(s);
            // Expression for the operation. e.g. N^2 > 10
            List<ExprNodeDesc> partKeyExprs = work.getEventSourcePartKeyExprMap().get(s);
            // eventSourceTableDesc, eventSourceColumnName, eventSourcePartKeyExpr move in lock-step.
            // One entry is added to each at the same time

            Iterator<String> cit = columnNames.iterator();
            Iterator<String> typit = columnTypes.iterator();
            Iterator<ExprNodeDesc> pit = partKeyExprs.iterator();
            // A single source can process multiple columns, and will send an event for each of them.
            for (TableDesc t : tables) {
                expectedEvents.decrement();
                ++sourceInfoCount;
                String columnName = cit.next();
                String columnType = typit.next();
                ExprNodeDesc partKeyExpr = pit.next();
                SourceInfo si = createSourceInfo(t, partKeyExpr, columnName, columnType, jobConf);

                List<SourceInfo> sis = sourceInfoMap.get(s);
                if (sis == null) {
                    sis = new ArrayList<SourceInfo>();
                    sourceInfoMap.put(s, sis);
                }
                sis.add(si);

                // We could have multiple sources restrict the same column, need to take
                // the union of the values in that case.
                SourceInfo previous = columnMap.get(columnName);
                if (previous != null) {
                    // All Sources are initialized up front. Events from different sources will end up getting
                    // added to the same set. Pruning is disabled if either source sends in an event which
                    // causes pruning to be skipped.
                    si.values = previous.values;
                    si.skipPruning = previous.skipPruning;
                }
                columnMap.put(columnName, si);
            }
        }
    }

    /**
     * Prunes the MapWork with every (source, column) pair and then checks that the number of received
     * events equals the sum of the source task counts over all pairs.
     *
     * @throws HiveException if pruning fails or the event count does not match
     */
    private void prunePartitions() throws HiveException {
        int expectedEvents = 0;
        for (Map.Entry<String, List<SourceInfo>> entry : sourceInfoMap.entrySet()) {
            String source = entry.getKey();
            // The task count depends only on the source vertex, not on the column.
            int taskNum = context.getVertexNumTasks(source);
            for (SourceInfo si : entry.getValue()) {
                LOG.info("Expecting {} events for vertex {}, for column {}", taskNum, source, si.columnName);
                expectedEvents += taskNum;
                prunePartitionSingleSource(source, si);
            }
        }

        // sanity check. all tasks must submit events for us to succeed.
        // totalEventCount is only written under the sourcesWaitingForEvents monitor, before the end-of-events
        // sentinel is queued; taking that sentinel from the queue makes those writes visible here.
        if (expectedEvents != totalEventCount) {
            LOG.error("Expecting: {}, received: {}", expectedEvents, totalEventCount);
            throw new HiveException("Incorrect event count in dynamic partition pruning");
        }
    }

    /**
     * Removes every partition path from the MapWork whose value for {@code si.columnName} is not among the
     * values collected for that column. The value is first converted to {@code si.columnType} and passed
     * through {@code si.partKey}. Nothing is pruned if pruning was marked as skipped.
     *
     * @param source event source vertex name (used for logging)
     * @param si     source/column whose collected values drive the pruning
     * @throws HiveException if the partition-key expression cannot be initialized or evaluated
     */
    @VisibleForTesting
    protected void prunePartitionSingleSource(String source, SourceInfo si) throws HiveException {

        if (si.skipPruning.get()) {
            // in this case we've determined that there's too much data
            // to prune dynamically.
            LOG.info("Skip pruning on {}, column {}", source, si.columnName);
            return;
        }

        Set<Object> values = si.values;
        String columnName = si.columnName;

        if (LOG.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder("Pruning ");
            sb.append(columnName);
            sb.append(" with ");
            for (Object value : values) {
                // append(Object) renders null as "null", matching the previous explicit null check.
                sb.append(value);
                sb.append(", ");
            }
            LOG.debug(sb.toString());
        }

        ObjectInspector oi = PrimitiveObjectInspectorFactory
                .getPrimitiveWritableObjectInspector(TypeInfoFactory.getPrimitiveTypeInfo(si.columnType));

        // Partition spec values are strings; convert them to the column's writable type before evaluation.
        Converter converter = ObjectInspectorConverters
                .getConverter(PrimitiveObjectInspectorFactory.javaStringObjectInspector, oi);

        StructObjectInspector soi = ObjectInspectorFactory.getStandardStructObjectInspector(
                Collections.singletonList(columnName), Collections.singletonList(oi));

        @SuppressWarnings("rawtypes")
        ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(si.partKey);
        eval.initialize(soi);

        applyFilterToPartitions(converter, eval, columnName, values);
    }

    /**
     * Evaluates {@code eval} on each partition's value for {@code columnName}. A partition path whose result
     * is not contained in {@code values} is dropped from pathToPartitionInfo and pathToAlias.
     *
     * @throws IllegalStateException if a partition has no spec or no value for {@code columnName}
     * @throws HiveException         if the expression evaluation fails
     */
    @SuppressWarnings("rawtypes")
    private void applyFilterToPartitions(Converter converter, ExprNodeEvaluator eval, String columnName,
            Set<Object> values) throws HiveException {

        Object[] row = new Object[1];

        Iterator<Map.Entry<Path, PartitionDesc>> it = work.getPathToPartitionInfo().entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<Path, PartitionDesc> entry = it.next();
            Path p = entry.getKey();
            Map<String, String> spec = entry.getValue().getPartSpec();
            if (spec == null) {
                throw new IllegalStateException("No partition spec found in dynamic pruning");
            }

            String partValueString = spec.get(columnName);
            if (partValueString == null) {
                throw new IllegalStateException("Could not find partition value for column: " + columnName);
            }

            Object partValue = converter.convert(partValueString);
            LOG.debug("Converted partition value: {} original ({})", partValue, partValueString);

            row[0] = partValue;
            partValue = eval.evaluate(row);
            LOG.debug("part key expr applied: {}", partValue);

            if (!values.contains(partValue)) {
                LOG.info("Pruning path: {}", p);
                // Removing through the iterator drops the entry from pathToPartitionInfo (this relies on the
                // getter returning the live map); the alias mapping is removed separately.
                it.remove();
                work.removePathToAlias(p);
            }
        }
    }

    /**
     * Factory for SourceInfo instances; overridable so tests can avoid real deserializers.
     *
     * @throws SerDeException if the deserializer cannot be set up
     */
    @VisibleForTesting
    protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName,
            String columnType, JobConf jobConf) throws SerDeException {
        return new SourceInfo(t, partKeyExpr, columnName, columnType, jobConf);
    }

    /**
     * Per (source vertex, column) state: how to decode the rows carried in event payloads, and the partition
     * values collected so far. SourceInfos for the same column from different sources share {@code values}
     * and {@code skipPruning} (wired up in initialize()).
     */
    @SuppressWarnings("deprecation")
    @VisibleForTesting
    static class SourceInfo {
        /* Expression applied to a partition value before it is looked up in {@code values}. */
        public final ExprNodeDesc partKey;
        /* Deserializer for the rows carried in event payloads. */
        public final Deserializer deserializer;
        /* Struct inspector of the deserialized row. */
        public final StructObjectInspector soi;
        /* The (first) field of the row that carries the partition value. */
        public final StructField field;
        /* Standard inspector used to copy field values out of the reused row. */
        public final ObjectInspector fieldInspector;
        /* List of partitions that are required - populated from processing each event */
        public Set<Object> values = new HashSet<Object>();
        /* Whether to skipPruning - depends on the payload from an event which may signal skip - if the event payload is too large */
        public AtomicBoolean skipPruning = new AtomicBoolean();
        public final String columnName;
        public final String columnType;

        @VisibleForTesting // Only used for testing.
        SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, String columnType, JobConf jobConf,
                Object forTesting) {
            this.partKey = partKey;
            this.columnName = columnName;
            this.columnType = columnType;
            this.deserializer = null;
            this.soi = null;
            this.field = null;
            this.fieldInspector = null;
        }

        /**
         * Instantiates and initializes the table's deserializer and resolves the field holding the value.
         *
         * @throws SerDeException if the deserializer cannot be initialized or its row type has no fields
         */
        public SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, String columnType,
                JobConf jobConf) throws SerDeException {

            this.skipPruning.set(false);

            this.partKey = partKey;

            this.columnName = columnName;
            this.columnType = columnType;

            deserializer = ReflectionUtils.newInstance(table.getDeserializerClass(), null);
            deserializer.initialize(jobConf, table.getProperties());

            ObjectInspector inspector = deserializer.getObjectInspector();
            LOG.debug("Type of obj insp: {}", inspector.getTypeName());

            soi = (StructObjectInspector) inspector;
            List<? extends StructField> fields = soi.getAllStructFieldRefs();
            if (fields.isEmpty()) {
                // Previously surfaced as an IndexOutOfBoundsException from fields.get(0).
                throw new SerDeException("expecting single field in input, found none");
            }
            if (fields.size() > 1) {
                LOG.error("expecting single field in input");
            }

            field = fields.get(0);

            fieldInspector = ObjectInspectorUtils.getStandardObjectInspector(field.getFieldObjectInspector());
        }
    }

    /**
     * Consumes events from the queue until the end-of-events sentinel is taken.
     *
     * @throws InterruptedException if interrupted while waiting on the queue
     */
    private void processEvents() throws SerDeException, IOException, InterruptedException {
        int eventCount = 0;

        while (true) {
            Object element = queue.take();

            if (element == endOfEvents) {
                // we're done processing events
                break;
            }

            InputInitializerEvent event = (InputInitializerEvent) element;
            ByteBuffer payload = event.getUserPayload();

            LOG.info("Input event: {}, {}, {}", event.getTargetInputName(), event.getTargetVertexName(),
                    payload.remaining());
            processPayload(payload, event.getSourceVertexName());
            eventCount++;
        }
        LOG.info("Received events: {}", eventCount);
    }

    /**
     * Decodes one event payload and records its values on the matching SourceInfo.
     * The payload holds the column name (modified UTF-8), a boolean skip flag, and then serialized rows
     * until the buffer is exhausted. The flag and rows are not read if pruning of the column is already
     * marked as skipped.
     *
     * @param payload    event payload; its position is advanced as it is read
     * @param sourceName name of the vertex that sent the event
     * @return {@code sourceName}
     * @throws IllegalStateException if no SourceInfo exists for the source or column
     * @throws SerDeException        if a row cannot be deserialized
     * @throws IOException           if the payload is malformed
     */
    @SuppressWarnings("deprecation")
    @VisibleForTesting
    protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, IOException {

        try (DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload))) {
            String columnName = in.readUTF();

            LOG.info("Source of event: {}", sourceName);

            List<SourceInfo> infos = sourceInfoMap.get(sourceName);
            if (infos == null) {
                throw new IllegalStateException("no source info for event source: " + sourceName);
            }

            SourceInfo info = null;
            for (SourceInfo si : infos) {
                if (columnName.equals(si.columnName)) {
                    info = si;
                    break;
                }
            }

            if (info == null) {
                throw new IllegalStateException("no source info for column: " + columnName);
            }

            if (info.skipPruning.get()) {
                // Marked as skipped previously. Don't bother processing the rest of the payload.
                return sourceName;
            }

            if (in.readBoolean()) {
                // The event asks for pruning on this column to be skipped.
                info.skipPruning.set(true);
                return sourceName;
            }

            // DataInputStream does not buffer and the wrapper reads straight from the payload, so the
            // payload's position tracks exactly how much has been consumed.
            while (payload.hasRemaining()) {
                writable.readFields(in);

                Object row = info.deserializer.deserialize(writable);

                Object value = info.soi.getStructFieldData(row, info.field);
                value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector);

                LOG.debug("Adding: {} to list of required partitions", value);
                info.values.add(value);
            }
        }
        return sourceName;
    }

    /** Minimal InputStream view over a ByteBuffer; every read advances the buffer's position. */
    private static class ByteBufferBackedInputStream extends InputStream {

        private final ByteBuffer buf;

        public ByteBufferBackedInputStream(ByteBuffer buf) {
            this.buf = buf;
        }

        @Override
        public int read() throws IOException {
            if (!buf.hasRemaining()) {
                return -1;
            }
            return buf.get() & 0xFF;
        }

        @Override
        public int read(byte[] bytes, int off, int len) throws IOException {
            if (off < 0 || len < 0 || len > bytes.length - off) {
                throw new IndexOutOfBoundsException(
                        "off=" + off + ", len=" + len + ", length=" + bytes.length);
            }
            if (len == 0) {
                // InputStream contract: a zero-length read returns 0, even at end of stream.
                return 0;
            }
            if (!buf.hasRemaining()) {
                return -1;
            }

            int count = Math.min(len, buf.remaining());
            buf.get(bytes, off, count);
            return count;
        }

        @Override
        public int available() {
            return buf.remaining();
        }
    }

    /**
     * Accepts an event if its source vertex is still being waited on, and queues it for processing.
     * Events from other vertices (e.g. ones already complete) are dropped.
     *
     * @param event pruning event sent by a source vertex task
     * @throws IllegalStateException if the queue rejects the event or more events than expected arrive
     */
    public void addEvent(InputInitializerEvent event) {
        String source = event.getSourceVertexName();
        synchronized (sourcesWaitingForEvents) {
            if (!sourcesWaitingForEvents.contains(source)) {
                return;
            }
            ++totalEventCount;
            numEventsSeenPerSource.get(source).increment();
            if (!queue.offer(event)) {
                throw new IllegalStateException("Queue full");
            }
            checkForSourceCompletion(source);
        }
    }

    /**
     * Handles the success of source vertex {@code name}. The source's expected event count is fixed at
     * #columns X #tasks, after which the source may be complete.
     *
     * @param name name of the vertex that succeeded
     * @throws IllegalStateException if {@code name} is not a known source, or its success was already handled
     */
    public void processVertex(String name) {
        LOG.info("Vertex succeeded: {}", name);
        synchronized (sourcesWaitingForEvents) {
            // Get a deterministic count of number of tasks for the vertex.
            MutableInt prevVal = numExpectedEventsPerSource.get(name);
            Preconditions.checkState(prevVal != null, "No expected event count for source: %s", name);
            int prevValInt = prevVal.intValue();
            Preconditions.checkState(prevValInt < 0,
                    "Invalid value for numExpectedEvents for source: " + name + ", oldVal=" + prevValInt);
            // prevValInt holds -(#columns); every task sends one event per column.
            prevVal.setValue((-1) * prevValInt * context.getVertexNumTasks(name));
            checkForSourceCompletion(name);
        }
    }

    /**
     * Marks source {@code name} complete once its success notification has been handled and exactly the
     * expected number of events has been seen. When it was the last outstanding source, queues the
     * end-of-events sentinel. Callers hold the {@code sourcesWaitingForEvents} monitor.
     *
     * @throws IllegalStateException if more events than expected were received, or the queue is full
     */
    private void checkForSourceCompletion(String name) {
        int expectedEvents = numExpectedEventsPerSource.get(name).intValue();
        if (expectedEvents < 0) {
            // Expected events not updated yet - vertex SUCCESS notification not received.
            return;
        }

        int processedEvents = numEventsSeenPerSource.get(name).intValue();
        if (processedEvents > expectedEvents) {
            throw new IllegalStateException("Received too many events for " + name + ", Expected="
                    + expectedEvents + ", Received=" + processedEvents);
        }
        if (processedEvents < expectedEvents) {
            // More events from this source are still outstanding.
            return;
        }

        sourcesWaitingForEvents.remove(name);
        if (sourcesWaitingForEvents.isEmpty()) {
            // we've got what we need; mark the queue
            if (!queue.offer(endOfEvents)) {
                throw new IllegalStateException("Queue full");
            }
        } else {
            LOG.info("Waiting for {} sources.", sourcesWaitingForEvents.size());
        }
    }
}